diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn index 9008d8ab096a..a0a87d662ede 160000 --- a/3rdparty/mkldnn +++ b/3rdparty/mkldnn @@ -1 +1 @@ -Subproject commit 9008d8ab096ae29f158840231ff431aea8bf3467 +Subproject commit a0a87d662edeef38d01db4ac5dd25f59a1f0881f diff --git a/CMakeLists.txt b/CMakeLists.txt index a06aa9dba485..0eba24f61d14 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,8 +29,7 @@ mxnet_option(USE_SSE "Build with x86 SSE instruction support" ON IF mxnet_option(USE_F16C "Build with x86 F16C instruction support" ON) # autodetects support if ON mxnet_option(USE_LAPACK "Build with lapack support" ON) mxnet_option(USE_MKL_IF_AVAILABLE "Use MKL if found" ON) -mxnet_option(USE_MKLML_MKL "Use MKLDNN variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND (NOT APPLE) AND (NOT MSVC) ) -mxnet_option(USE_MKLDNN "Use MKLDNN variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND (NOT APPLE) AND (NOT MSVC) AND (CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86_64") AND (NOT CMAKE_CROSSCOMPILING)) +mxnet_option(USE_MKLDNN "Build with MKL-DNN support" ON IF USE_MKL_IF_AVAILABLE AND (NOT APPLE) AND (NOT MSVC) AND (CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86_64") AND (NOT CMAKE_CROSSCOMPILING)) mxnet_option(USE_OPERATOR_TUNING "Enable auto-tuning of operators" ON IF NOT MSVC) mxnet_option(USE_GPERFTOOLS "Build with GPerfTools support" OFF) mxnet_option(USE_JEMALLOC "Build with Jemalloc support" ON) @@ -257,25 +256,22 @@ if(ENABLE_TESTCOVERAGE) endif() if(USE_MKLDNN) - include(cmake/DownloadMKLML.cmake) # CPU architecture (e.g., C5) can't run on another architecture (e.g., g3). - if(NOT MSVC) - set(ARCH_OPT_FLAGS "-mtune=generic") - else() + if(MSVC) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /EHsc") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /EHsc /Gy") endif() - set(WITH_TEST OFF CACHE INTERNAL "" FORCE) - set(WITH_EXAMPLE OFF CACHE INTERNAL "" FORCE) - set(ARCH_OPT_FLAGS "" CACHE INTERNAL "" FORCE) + set(MKLDNN_BUILD_TESTS OFF CACHE INTERNAL "" FORCE) + set(MKLDNN_BUILD_EXAMPLES OFF CACHE INTERNAL "" FORCE) + set(MKLDNN_ARCH_OPT_FLAGS "" CACHE INTERNAL "" FORCE) + set(MKLDNN_USE_MKL NONE CACHE INTERNAL "" FORCE) + set(MKLDNN_ENABLE_JIT_PROFILING OFF CACHE INTERNAL "" FORCE) add_subdirectory(3rdparty/mkldnn) include_directories(3rdparty/mkldnn/include) include_directories(${PROJECT_BINARY_DIR}/3rdparty/mkldnn/include) - add_definitions(-DUSE_MKL=1) - add_definitions(-DCUB_MKL=1) add_definitions(-DMXNET_USE_MKLDNN=1) list(APPEND mxnet_LINKER_LIBS mkldnn) endif() diff --git a/LICENSE b/LICENSE index 4532449470e8..877e458ef7dd 100644 --- a/LICENSE +++ b/LICENSE @@ -651,43 +651,8 @@ ======================================================================================= - - 13. MKL BLAS - For details, see, [Intel® Simplified license](https://software.intel.com/en-us/license/intel-simplified-software-license) and MKLDNN_README.md - - Copyright (c) 2018 Intel Corporation. - - Use and Redistribution. You may use and redistribute the software (the “Software”), without modification, provided the following conditions are met: - - * Redistributions must reproduce the above copyright notice and the following terms of use in the Software and in the documentation and/or other materials provided with the distribution. - - * Neither the name of Intel nor the names of its suppliers may be used to endorse or promote products derived from this Software without specific prior written permission. - - * No reverse engineering, decompilation, or disassembly of this Software is permitted. - - Limited patent license. Intel grants you a world-wide, royalty-free, non-exclusive license under patents it now or hereafter owns or controls to make, have made, use, import, offer to sell and sell (“Utilize”) this Software, but solely to the extent that any such patent is necessary to Utilize the Software alone. The patent license shall not apply to any combinations which include this software. No hardware per se is licensed hereunder. - - Third party and other Intel programs. “Third Party Programs” are the files listed in the “third-party-programs.txt” text file that is included with the Software and may include Intel programs under separate license terms. Third Party Programs, even if included with the distribution of the Materials, are governed by separate license terms and those license terms solely govern your use of those programs. - - DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND ATTORNEYS’ FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE MATERIALS. - - LIMITATION OF LIABILITY. IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. YOU AGREE TO INDEMNIFY AND HOLD INTEL HARMLESS AGAINST ANY CLAIMS AND EXPENSES RESULTING FROM YOUR USE OR UNAUTHORIZED USE OF THE SOFTWARE. - - No support. Intel may make changes to the Software, at any time without notice, and is not obligated to support, update or provide training for the Software. - - Termination. Intel may terminate your right to use the Software in the event of your breach of this Agreement and you fail to cure the breach within a reasonable period of time. - - Feedback. Should you provide Intel with comments, modifications, corrections, enhancements or other input (“Feedback”) related to the Software Intel will be free to use, disclose, reproduce, license or otherwise distribute or exploit the Feedback in its sole discretion without any obligations or restrictions of any kind, including without limitation, intellectual property rights or licensing obligations. - - Compliance with laws. You agree to comply with all relevant laws and regulations governing your use, transfer, import or export (or prohibition thereof) of the Software. - - Governing law. All disputes will be governed by the laws of the United States of America and the State of Delaware without reference to conflict of law principles and subject to the exclusive jurisdiction of the state or federal courts sitting in the State of Delaware, and each party agrees that it submits to the personal jurisdiction and venue of those courts and waives any objections. The United Nations Convention on Contracts for the International Sale of Goods (1980) is specifically excluded and will not apply to the Software. - - *Other names and brands may be claimed as the property of others. - - ======================================================================================= - - 14. FindJeMalloc.cmake + + 13. FindJeMalloc.cmake For details, see cmake/Modules/FindJeMalloc.cmake This file is based on https://github.com/STEllAR-GROUP/hpx/blob/master/cmake/FindJemalloc.cmake @@ -778,7 +743,7 @@ ======================================================================================= - 15. FindPythonLibsNew.cmake + 14. FindPythonLibsNew.cmake For details, see 3rdparty/onnx-tensorrt/third_party/onnx/third_party/pybind11/tools/FindPythonLibsNew.cmake @@ -817,7 +782,7 @@ ======================================================================================= - 16. erfinv-inl.h + 15. erfinv-inl.h For details, see /src/operator/contrib/erfinv-inl.h @@ -860,7 +825,7 @@ ======================================================================================= - 17. mersenne.h + 16. mersenne.h For details, see /3rdparty/nvidia_cub/test/mersenne.h @@ -909,7 +874,7 @@ ======================================================================================= - 18. FindEigen3.cmake + 17. FindEigen3.cmake For details, see /3rdparty/onnx-tensorrt/third_party/onnx/third_party/pybind11/tools/FindEigen3.cmake @@ -920,7 +885,7 @@ ======================================================================================= - 19. protoc-gen-mypy.py + 18. protoc-gen-mypy.py For details, see /3rdparty/onnx-tensorrt/third_party/onnx/tools/protoc-gen-mypy.py @@ -936,7 +901,7 @@ ======================================================================================= - 20. rang + 19. rang For details, see /3rdparty/tvm/3rdparty/rang/LICENSE diff --git a/Makefile b/Makefile index 63a978d01d8a..4746cc434de2 100644 --- a/Makefile +++ b/Makefile @@ -84,8 +84,6 @@ endif ifeq ($(USE_MKLDNN), 1) MKLDNNROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install - MKLROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install - export USE_MKLML = 1 endif include $(TPARTYDIR)/mshadow/make/mshadow.mk @@ -151,14 +149,9 @@ endif ifeq ($(USE_MKLDNN), 1) CFLAGS += -DMXNET_USE_MKLDNN=1 - CFLAGS += -DUSE_MKL=1 CFLAGS += -I$(ROOTDIR)/src/operator/nn/mkldnn/ - ifneq ($(MKLDNNROOT), $(MKLROOT)) - CFLAGS += -I$(MKLROOT)/include - LDFLAGS += -L$(MKLROOT)/lib - endif CFLAGS += -I$(MKLDNNROOT)/include - LDFLAGS += -L$(MKLDNNROOT)/lib -lmkldnn -Wl,-rpath,'$${ORIGIN}' + LDFLAGS += -L$(MKLDNNROOT)/lib -L$(MKLDNNROOT)/lib64 -lmkldnn -Wl,-rpath,'$${ORIGIN}' endif # setup opencv @@ -604,9 +597,7 @@ lib/libmxnet.so: $(ALLX_DEP) -Wl,${WHOLE_ARCH} $(filter %libnnvm.a, $^) -Wl,${NO_WHOLE_ARCH} ifeq ($(USE_MKLDNN), 1) ifeq ($(UNAME_S), Darwin) - install_name_tool -change '@rpath/libmklml.dylib' '@loader_path/libmklml.dylib' $@ - install_name_tool -change '@rpath/libiomp5.dylib' '@loader_path/libiomp5.dylib' $@ - install_name_tool -change '@rpath/libmkldnn.0.dylib' '@loader_path/libmkldnn.0.dylib' $@ + install_name_tool -change '@rpath/libmkldnn.1.dylib' '@loader_path/libmkldnn.1.dylib' $@ endif endif @@ -698,10 +689,8 @@ rpkg: cp src/io/image_recordio.h R-package/src cp -rf lib/libmxnet.so R-package/inst/libs - if [ -e "lib/libmkldnn.so.0" ]; then \ - cp -rf lib/libmkldnn.so.0 R-package/inst/libs; \ - cp -rf lib/libiomp5.so R-package/inst/libs; \ - cp -rf lib/libmklml_intel.so R-package/inst/libs; \ + if [ -e "lib/libmkldnn.so.1" ]; then \ + cp -rf lib/libmkldnn.so.1 R-package/inst/libs; \ fi if [ -e "lib/libtvm_runtime.so" ]; then \ diff --git a/ci/docker/Dockerfile.build.centos7_cpu b/ci/docker/Dockerfile.build.centos7_cpu index e2802aa2fb2b..0cfa5a9f6e47 100644 --- a/ci/docker/Dockerfile.build.centos7_cpu +++ b/ci/docker/Dockerfile.build.centos7_cpu @@ -30,8 +30,6 @@ COPY install/centos7_python.sh /work/ RUN /work/centos7_python.sh COPY install/centos7_scala.sh /work/ RUN /work/centos7_scala.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh ARG USER_ID=0 COPY install/centos7_adduser.sh /work/ diff --git a/ci/docker/Dockerfile.build.ubuntu_build_cuda b/ci/docker/Dockerfile.build.ubuntu_build_cuda index e085c2dc09a0..ce6d0007875e 100644 --- a/ci/docker/Dockerfile.build.ubuntu_build_cuda +++ b/ci/docker/Dockerfile.build.ubuntu_build_cuda @@ -42,8 +42,6 @@ COPY install/ubuntu_perl.sh /work/ RUN /work/ubuntu_perl.sh COPY install/ubuntu_clang.sh /work/ RUN /work/ubuntu_clang.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh COPY install/ubuntu_ar.sh /work/ RUN /work/ubuntu_ar.sh diff --git a/ci/docker/Dockerfile.build.ubuntu_cpu b/ci/docker/Dockerfile.build.ubuntu_cpu index f41d629289a7..b1eb89bb3f36 100644 --- a/ci/docker/Dockerfile.build.ubuntu_cpu +++ b/ci/docker/Dockerfile.build.ubuntu_cpu @@ -58,9 +58,6 @@ RUN /work/ubuntu_gcc8.sh COPY install/ubuntu_mkl.sh /work/ RUN /work/ubuntu_mkl.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_caffe.sh /work/ RUN /work/ubuntu_caffe.sh diff --git a/ci/docker/Dockerfile.build.ubuntu_cpu_julia b/ci/docker/Dockerfile.build.ubuntu_cpu_julia index 108869b680cd..b1eb89bb3f36 100644 --- a/ci/docker/Dockerfile.build.ubuntu_cpu_julia +++ b/ci/docker/Dockerfile.build.ubuntu_cpu_julia @@ -58,9 +58,6 @@ RUN /work/ubuntu_gcc8.sh COPY install/ubuntu_mkl.sh /work/ RUN /work/ubuntu_mkl.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_caffe.sh /work/ RUN /work/ubuntu_caffe.sh @@ -78,4 +75,4 @@ RUN /work/ubuntu_adduser.sh COPY runtime_functions.sh /work/ -WORKDIR /work/mxnet \ No newline at end of file +WORKDIR /work/mxnet diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu100 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu100 index 894930bc6303..514f6bb1495b 100644 --- a/ci/docker/Dockerfile.build.ubuntu_gpu_cu100 +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu100 @@ -46,9 +46,6 @@ RUN /work/ubuntu_perl.sh COPY install/ubuntu_clang.sh /work/ RUN /work/ubuntu_clang.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_tvm.sh /work/ RUN /work/ubuntu_tvm.sh diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu101 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu101 index 9699b37aa45f..7e0f8d93ed37 100644 --- a/ci/docker/Dockerfile.build.ubuntu_gpu_cu101 +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu101 @@ -46,9 +46,6 @@ RUN /work/ubuntu_perl.sh COPY install/ubuntu_clang.sh /work/ RUN /work/ubuntu_clang.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_tvm.sh /work/ RUN /work/ubuntu_tvm.sh diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu80 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu80 index a1031af811cf..83f05fddf261 100644 --- a/ci/docker/Dockerfile.build.ubuntu_gpu_cu80 +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu80 @@ -46,9 +46,6 @@ RUN /work/ubuntu_perl.sh COPY install/ubuntu_clang.sh /work/ RUN /work/ubuntu_clang.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_tvm.sh /work/ RUN /work/ubuntu_tvm.sh diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu90 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu90 index 56ebd55c94e0..579ad7bffddf 100644 --- a/ci/docker/Dockerfile.build.ubuntu_gpu_cu90 +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu90 @@ -46,9 +46,6 @@ RUN /work/ubuntu_perl.sh COPY install/ubuntu_clang.sh /work/ RUN /work/ubuntu_clang.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_tvm.sh /work/ RUN /work/ubuntu_tvm.sh diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu92 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu92 index f6008a5c09ca..dc125aee371d 100644 --- a/ci/docker/Dockerfile.build.ubuntu_gpu_cu92 +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu92 @@ -46,9 +46,6 @@ RUN /work/ubuntu_perl.sh COPY install/ubuntu_clang.sh /work/ RUN /work/ubuntu_clang.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_tvm.sh /work/ RUN /work/ubuntu_tvm.sh diff --git a/ci/docker/Dockerfile.build.ubuntu_nightly_cpu b/ci/docker/Dockerfile.build.ubuntu_nightly_cpu index 8e36a74cbde7..5717df1b9130 100644 --- a/ci/docker/Dockerfile.build.ubuntu_nightly_cpu +++ b/ci/docker/Dockerfile.build.ubuntu_nightly_cpu @@ -46,9 +46,6 @@ RUN /work/ubuntu_perl.sh COPY install/ubuntu_clang.sh /work/ RUN /work/ubuntu_clang.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_caffe.sh /work/ RUN /work/ubuntu_caffe.sh diff --git a/ci/docker/Dockerfile.build.ubuntu_nightly_gpu b/ci/docker/Dockerfile.build.ubuntu_nightly_gpu index 7d051b1bd04a..5e812c433b43 100644 --- a/ci/docker/Dockerfile.build.ubuntu_nightly_gpu +++ b/ci/docker/Dockerfile.build.ubuntu_nightly_gpu @@ -46,9 +46,6 @@ RUN /work/ubuntu_perl.sh COPY install/ubuntu_clang.sh /work/ RUN /work/ubuntu_clang.sh -COPY install/ubuntu_mklml.sh /work/ -RUN /work/ubuntu_mklml.sh - COPY install/ubuntu_tvm.sh /work/ RUN /work/ubuntu_tvm.sh diff --git a/ci/docker/install/ubuntu_mklml.sh b/ci/docker/install/ubuntu_mklml.sh deleted file mode 100755 index 99fd0b9e01f2..000000000000 --- a/ci/docker/install/ubuntu_mklml.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env bash - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# build and install are separated so changes to build don't invalidate -# the whole docker cache for the image - -set -ex -wget -q --no-check-certificate -O /tmp/mklml.tgz https://github.com/intel/mkl-dnn/releases/download/v0.21/mklml_lnx_2019.0.5.20190502.tgz -tar -zxf /tmp/mklml.tgz && cp -rf mklml_*/* /usr/local/ && rm -rf mklml_* diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index b53db3f980f1..581bb2fd5280 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -692,6 +692,7 @@ build_ubuntu_cpu_mkldnn_mkl() { USE_TVM_OP=1 \ USE_BLAS=mkl \ USE_SIGNAL_HANDLER=1 \ + USE_INTEL_PATH=/opt/intel/ \ -j$(nproc) } @@ -877,9 +878,9 @@ build_ubuntu_gpu_cmake_mkldnn() { /work/mxnet ninja -v - # libmkldnn.so.0 is a link file. We need an actual binary file named libmkldnn.so.0. - cp 3rdparty/mkldnn/src/libmkldnn.so.0 3rdparty/mkldnn/src/libmkldnn.so.0.tmp - mv 3rdparty/mkldnn/src/libmkldnn.so.0.tmp 3rdparty/mkldnn/src/libmkldnn.so.0 + # libmkldnn.so.1 is a link file. We need an actual binary file named libmkldnn.so.1. + cp 3rdparty/mkldnn/src/libmkldnn.so.1 3rdparty/mkldnn/src/libmkldnn.so.1.tmp + mv 3rdparty/mkldnn/src/libmkldnn.so.1.tmp 3rdparty/mkldnn/src/libmkldnn.so.1 } build_ubuntu_gpu_cmake() { diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index d7c2b9679ca3..0770320f1407 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -35,8 +35,8 @@ mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libsample_l mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' -mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' -mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.1' +mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' diff --git a/cmake/DownloadMKLML.cmake b/cmake/DownloadMKLML.cmake deleted file mode 100644 index bcb2fb408fd4..000000000000 --- a/cmake/DownloadMKLML.cmake +++ /dev/null @@ -1,78 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This script will download MKLML - -message(STATUS "Downloading MKLML...") - -set(MKLDNN_RELEASE v0.21) -set(MKLML_RELEASE_FILE_SUFFIX 2019.0.5.20190502) - -set(MKLML_LNX_MD5 dfcea335652dbf3518e1d02cab2cea97) -set(MKLML_WIN_MD5 ff8c5237570f03eea37377ccfc95a08a) -set(MKLML_MAC_MD5 0a3d83ec1fed9ea318e8573bb5e14c24) - -if(MSVC) - set(MKL_NAME "mklml_win_${MKLML_RELEASE_FILE_SUFFIX}") - - file(DOWNLOAD "https://github.com/intel/mkl-dnn/releases/download/${MKLDNN_RELEASE}/${MKL_NAME}.zip" - "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}.zip" - EXPECTED_MD5 "${MKLML_WIN_MD5}" SHOW_PROGRESS) - file(DOWNLOAD "https://github.com/apache/incubator-mxnet/releases/download/utils/7z.exe" - "${CMAKE_CURRENT_BINARY_DIR}/mklml/7z2.exe" - EXPECTED_MD5 "E1CF766CF358F368EC97662D06EA5A4C" SHOW_PROGRESS) - - execute_process(COMMAND "${CMAKE_CURRENT_BINARY_DIR}/mklml/7z2.exe" "-o${CMAKE_CURRENT_BINARY_DIR}/mklml/" "-y") - execute_process(COMMAND "${CMAKE_CURRENT_BINARY_DIR}/mklml/7z.exe" - "x" "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}.zip" "-o${CMAKE_CURRENT_BINARY_DIR}/mklml/" "-y") - - set(MKLROOT "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}") - - message(STATUS "Setting MKLROOT path to ${MKLROOT}") - - include_directories(${MKLROOT}/include) - -elseif(APPLE) - set(MKL_NAME "mklml_mac_${MKLML_RELEASE_FILE_SUFFIX}") - - file(DOWNLOAD "https://github.com/intel/mkl-dnn/releases/download/${MKLDNN_RELEASE}/${MKL_NAME}.tgz" - "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}.tgz" - EXPECTED_MD5 "${MKLML_MAC_MD5}" SHOW_PROGRESS) - execute_process(COMMAND "tar" "-xzf" "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}.tgz" - "-C" "${CMAKE_CURRENT_BINARY_DIR}/mklml/") - - set(MKLROOT "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}") - - message(STATUS "Setting MKLROOT path to ${MKLROOT}") - include_directories(${MKLROOT}/include) - -elseif(UNIX) - set(MKL_NAME "mklml_lnx_${MKLML_RELEASE_FILE_SUFFIX}") - - file(DOWNLOAD "https://github.com/intel/mkl-dnn/releases/download/${MKLDNN_RELEASE}/${MKL_NAME}.tgz" - "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}.tgz" - EXPECTED_MD5 "${MKLML_LNX_MD5}" SHOW_PROGRESS) - execute_process(COMMAND "tar" "-xzf" "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}.tgz" - "-C" "${CMAKE_CURRENT_BINARY_DIR}/mklml/") - - set(MKLROOT "${CMAKE_CURRENT_BINARY_DIR}/mklml/${MKL_NAME}") - message(STATUS "Setting MKLROOT path to ${MKLROOT}") - include_directories(${MKLROOT}/include) - -else() - message(FATAL_ERROR "Wrong platform") -endif() diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 176aa0aaa197..1b0b119a02ac 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -709,7 +709,7 @@ class NDArray { * Create NDArray from mkldnn memory descriptor. * mem_pd The mkldnn memory descriptor to be created. */ - explicit NDArray(mkldnn::memory::primitive_desc mem_pd); + explicit NDArray(const mkldnn::memory::desc &md); /* * Test if the data is stored in one of special MKLDNN format. */ @@ -737,15 +737,14 @@ class NDArray { * This function returns mkldnn::memory with the given primitive_desc * as long as the array size meets the required size in the given primitive_desc. */ - const mkldnn::memory *GetMKLDNNData( - const mkldnn::memory::primitive_desc &desc) const; + const mkldnn::memory *GetMKLDNNData(const mkldnn::memory::desc &md) const; /* * This function returns mkldnn::memory with the given primitive_desc. * The returned mkldnn::memory will have the same physical layout as * the given primitive_desc. */ const mkldnn::memory *GetMKLDNNDataReorder( - const mkldnn::memory::primitive_desc &desc) const; + const mkldnn::memory::desc &md) const; /* * This function copies data from mkldnn memory. @@ -755,16 +754,15 @@ class NDArray { * This function allocates memory for array and creates mkldnn memory * with the specified format. */ - mkldnn::memory *CreateMKLDNNData( - const mkldnn::memory::primitive_desc &desc); + mkldnn::memory *CreateMKLDNNData(const mkldnn::memory::desc &md); /* * These are the async version of the methods above. * It changes the layout of this NDArray, but it happens after all accesses to * the array are complete. */ - void Reorder2DefaultAsync(); - void MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc); + void Reorder2DefaultAsync() const; + void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md) const; /* * This creates a new NDArray with the reordered data. @@ -789,7 +787,7 @@ class NDArray { /*! * \ Fix mkldnn memory descriptor mismatch from NDArray. */ - void UpdateMKLDNNMemDesc(mkldnn::memory::format format); + void UpdateMKLDNNMemDesc(const mkldnn::memory::desc &desc); #endif /*! @@ -1054,7 +1052,7 @@ class NDArray { // save the result in shandle. void Reorder2Default(); // Reroder data to a specified layout. - void MKLDNNDataReorder(const mkldnn::memory::primitive_desc &desc); + void MKLDNNDataReorder(const mkldnn::memory::desc &md); bool IsMKLDNN() const; bool IsDefault() const; #endif diff --git a/mkldnn.mk b/mkldnn.mk index 802f3dc747c2..bc2190018bdf 100644 --- a/mkldnn.mk +++ b/mkldnn.mk @@ -21,15 +21,11 @@ ifeq ($(USE_MKLDNN), 1) MXNET_LIBDIR = $(ROOTDIR)/lib MXNET_INCLDIR = $(ROOTDIR)/include ifeq ($(UNAME_S), Darwin) - OMP_LIBFILE = $(MKLDNNROOT)/lib/libiomp5.dylib - MKLML_LIBFILE = $(MKLDNNROOT)/lib/libmklml.dylib - MKLDNN_LIBFILE = $(MKLDNNROOT)/lib/libmkldnn.0.dylib - MKLDNN_LIB64FILE = $(MKLDNNROOT)/lib64/libmkldnn.0.dylib + MKLDNN_LIBFILE = $(MKLDNNROOT)/lib/libmkldnn.1.dylib + MKLDNN_LIB64FILE = $(MKLDNNROOT)/lib64/libmkldnn.1.dylib else - OMP_LIBFILE = $(MKLDNNROOT)/lib/libiomp5.so - MKLML_LIBFILE = $(MKLDNNROOT)/lib/libmklml_intel.so - MKLDNN_LIBFILE = $(MKLDNNROOT)/lib/libmkldnn.so.0 - MKLDNN_LIB64FILE = $(MKLDNNROOT)/lib64/libmkldnn.so.0 + MKLDNN_LIBFILE = $(MKLDNNROOT)/lib/libmkldnn.so.1 + MKLDNN_LIB64FILE = $(MKLDNNROOT)/lib64/libmkldnn.so.1 endif endif @@ -38,22 +34,21 @@ endif mkldnn_build: $(MKLDNN_LIBFILE) $(MKLDNN_LIBFILE): - mkdir -p $(MKLDNNROOT) - cd $(MKLDNN_SUBMODDIR) && rm -rf external && cd scripts && ./prepare_mkl.sh && cd .. && cp -a external/*/* $(MKLDNNROOT)/. - cmake $(MKLDNN_SUBMODDIR) -DCMAKE_INSTALL_PREFIX=$(MKLDNNROOT) -B$(MKLDNN_BUILDDIR) -DARCH_OPT_FLAGS="-mtune=generic" -DWITH_TEST=OFF -DWITH_EXAMPLE=OFF + mkdir -p $(MKLDNNROOT)/lib + cmake $(MKLDNN_SUBMODDIR) -DCMAKE_INSTALL_PREFIX=$(MKLDNNROOT) -B$(MKLDNN_BUILDDIR) -DMKLDNN_ARCH_OPT_FLAGS="" -DMKLDNN_BUILD_TESTS=OFF -DMKLDNN_BUILD_EXAMPLES=OFF -DMKLDNN_ENABLE_JIT_PROFILING=OFF $(MAKE) -C $(MKLDNN_BUILDDIR) VERBOSE=1 $(MAKE) -C $(MKLDNN_BUILDDIR) install + mkdir -p $(MXNET_LIBDIR) if [ -f "$(MKLDNN_LIB64FILE)" ]; then \ - mv $(MKLDNNROOT)/lib64/libmkldnn* $(MKLDNNROOT)/lib/; \ + cp $(MKLDNNROOT)/lib64/libmkldnn* $(MXNET_LIBDIR); \ + cp $(MKLDNNROOT)/lib64/libmkldnn* $(MKLDNNROOT)/lib/; \ + else \ + cp $(MKLDNNROOT)/lib/libmkldnn* $(MXNET_LIBDIR); \ fi - mkdir -p $(MXNET_LIBDIR) - cp $(OMP_LIBFILE) $(MXNET_LIBDIR) - cp $(MKLML_LIBFILE) $(MXNET_LIBDIR) - cp $(MKLDNN_LIBFILE) $(MXNET_LIBDIR) cp $(MKLDNN_BUILDDIR)/include/mkldnn_version.h $(MXNET_INCLDIR)/mkldnn/. + mkldnn_clean: $(RM) -r 3rdparty/mkldnn/build - $(RM) -r $(MKLDNNROOT) ifeq ($(USE_MKLDNN), 1) mkldnn: mkldnn_build diff --git a/scala-package/assembly/src/main/assembly/assembly.xml b/scala-package/assembly/src/main/assembly/assembly.xml index bcc5408cd3db..060a97b82064 100644 --- a/scala-package/assembly/src/main/assembly/assembly.xml +++ b/scala-package/assembly/src/main/assembly/assembly.xml @@ -57,12 +57,8 @@ libtvm_runtime.so libgfortran.so.3 libquadmath.so.0 - libiomp5.so - libiomp5.dylib - libmklml_intel.so - libmklml.dylib - libmkldnn.so.0 - libmkldnn.0.dylib + libmkldnn.so.1 + libmkldnn.1.dylib lib/native diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/util/NativeLibraryLoader.scala b/scala-package/core/src/main/scala/org/apache/mxnet/util/NativeLibraryLoader.scala index 9609ba25da40..49e5d685adfe 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/util/NativeLibraryLoader.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/util/NativeLibraryLoader.scala @@ -89,12 +89,8 @@ private[mxnet] object NativeLibraryLoader { saveLibraryToTemp("libtvm_runtime.so", "/lib/native/libtvm_runtime.so", false) saveLibraryToTemp("libgfortran.so.3", "/lib/native/libgfortran.so.3", false) saveLibraryToTemp("libquadmath.so.0", "/lib/native/libquadmath.so.0", false) - saveLibraryToTemp("libiomp5.so", "/lib/native/libiomp5.so", false) - saveLibraryToTemp("libiomp5.dylib", "/lib/native/libiomp5.dylib", false) - saveLibraryToTemp("libmklml_intel.so", "/lib/native/libmklml_intel.so", false) - saveLibraryToTemp("libmklml.dylib", "/lib/native/libmklml.dylib", false) - saveLibraryToTemp("libmkldnn.so.0", "/lib/native/libmkldnn.so.0", false) - saveLibraryToTemp("libmkldnn.0.dylib", "/lib/native/libmkldnn.0.dylib", false) + saveLibraryToTemp("libmkldnn.so.1", "/lib/native/libmkldnn.so.1", false) + saveLibraryToTemp("libmkldnn.1.dylib", "/lib/native/libmkldnn.1.dylib", false) val tempfile: File = saveLibraryToTemp(libname, libFileInJar, true) loadLibraryFromFile(libname, tempfile) diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 8317c6073a24..b609c54b50f3 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -477,8 +477,18 @@ inline void PushFComputeEx(const FComputeEx& fn, // copying A to B may not happen, and will corrupt A's memory. InvalidateOutputs(outputs, req); } -#endif + // add for mkldnn OP + no mkldnn OP + const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); + if (!is_mkldnn.get(attrs.op, false)) { + std::vector inputs_fallback; + CreateDefaultInputs(inputs, &inputs_fallback); + fn(attrs, opctx, inputs_fallback, req, outputs); + } else { + fn(attrs, opctx, inputs, req, outputs); + } +#else fn(attrs, opctx, inputs, req, outputs); +#endif if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && !rctx.is_bulk) { rctx.get_stream()->Wait(); } @@ -531,8 +541,18 @@ inline void PushOperator(const OpStatePtr& state, // copying A to B may not happen, and will corrupt A's memory. InvalidateOutputs(outputs, req); } -#endif + // add for mkldnn OP + no mkldnn OP + const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); + if (!is_mkldnn.get(attrs.op, false)) { + std::vector inputs_fallback; + CreateDefaultInputs(inputs, &inputs_fallback); + fcompute_ex(state, opctx, inputs_fallback, req, outputs); + } else { + fcompute_ex(state, opctx, inputs, req, outputs); + } +#else fcompute_ex(state, opctx, inputs, req, outputs); +#endif if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && rctx.get_stream() && !rctx.is_bulk) { rctx.get_stream()->Wait(); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index e1075c9c15da..f0dca2ea2aee 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -31,9 +31,6 @@ #include #include #include -#if MXNET_USE_MKLDNN == 1 -#include -#endif #include "./ndarray_function.h" #include "../common/utils.h" #include "../operator/tensor/matrix_op-inl.h" @@ -182,25 +179,23 @@ nnvm::Symbol NDArray::get_autograd_symbol() const { #if MXNET_USE_MKLDNN == 1 -NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd) +NDArray::NDArray(const mkldnn::memory::desc &md) : storage_type_(kDefaultStorage), entry_(nullptr) { - auto mem_desc = mem_pd.desc(); - shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); - dtype_ = get_mxnet_type(mem_desc.data.data_type); + shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims); + dtype_ = get_mxnet_type(md.data.data_type); ptr_ = std::make_shared(shape_, Context::CPU(), true, dtype_); - ptr_->CheckAndAlloc(mem_pd.get_size()); - ptr_->mkl_mem_ = std::make_shared(mem_pd, ptr_->shandle.dptr); + ptr_->CheckAndAlloc(md.get_size()); + ptr_->mkl_mem_ = std::make_shared(md, ptr_->shandle.dptr); } NDArray::NDArray(const std::shared_ptr &mkldnn_mem) : storage_type_(kDefaultStorage), entry_(nullptr) { - auto mem_pd = mkldnn_mem->get_primitive_desc(); - auto mem_desc = mem_pd.desc(); + auto mem_desc = mkldnn_mem->get_desc(); shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); dtype_ = get_mxnet_type(mem_desc.data.data_type); ptr_ = std::make_shared(shape_, Context::CPU(), true, dtype_); ptr_->shandle.dptr = mkldnn_mem->get_data_handle(); - ptr_->shandle.size = mem_pd.get_size(); + ptr_->shandle.size = mem_desc.get_size(); ptr_->delay_alloc = false; ptr_->mkl_mem_ = std::make_shared(mkldnn_mem); ptr_->static_data = true; @@ -219,22 +214,24 @@ NDArray NDArray::MKLDNNDataReshape(const mxnet::TShape &shape) const { NDArray ret(shape, ctx(), true, dtype()); // We shouldn't submit the reorder primitive here because submit will // be called in operators. - mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat(); - CHECK_NE(format, ptr_->mkl_mem_->GetFormat()); - mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format); - mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd); + mkldnn_format_tag_t format = ptr_->mkl_mem_->GetDefaultFormat(); + CHECK(ptr_->IsMKLDNN()); + mkldnn::memory::desc def_desc = ptr_->mkl_mem_->GetDesc(format); + mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_desc); MKLDNNStream *stream = MKLDNNStream::Get(); std::shared_ptr curr_mem = ptr_->mkl_mem_->GetMem(); stream->RegisterMem(curr_mem); - stream->RegisterPrim(mkldnn::reorder(*curr_mem, *def_mem)); + std::unordered_map args({{MKLDNN_ARG_FROM, *curr_mem}, + {MKLDNN_ARG_TO, *def_mem}}); + stream->RegisterPrimArgs(mkldnn::reorder(*curr_mem, *def_mem), args); // def_mem points to a memory region in the temp space. It's only valid // inside an operator. As such, the returned NDArray can only be valid // inside an operator and the shared point doesn't need to do anything // when it's destroyed. - auto tmp = std::shared_ptr(def_mem, [](mkldnn::memory *mem){}); + auto tmp = std::shared_ptr(def_mem, [](mkldnn::memory *mem) {}); ret.ptr_->mkl_mem_.reset(new MKLDNNMemory(tmp)); ret.ptr_->shandle.dptr = def_mem->get_data_handle(); - ret.ptr_->shandle.size = def_mem->get_primitive_desc().get_size(); + ret.ptr_->shandle.size = def_mem->get_desc().get_size(); ret.ptr_->delay_alloc = false; ret.ptr_->static_data = true; ret.byte_offset_ = byte_offset_; @@ -242,7 +239,6 @@ NDArray NDArray::MKLDNNDataReshape(const mxnet::TShape &shape) const { return ret; } } - #endif NDArray NDArray::Reshape(const mxnet::TShape &shape) const { @@ -420,57 +416,56 @@ void NDArray::Chunk::Reorder2Default() { if (mkl_mem_ == nullptr) return; - mkldnn_memory_format_t format = mkl_mem_->GetDefaultFormat(); - if (format == mkl_mem_->GetFormat()) + if (IsDefault()) return; - mkldnn::memory::primitive_desc def_pd = mkl_mem_->GetPrimitiveDesc(format); - mkldnn_mem_ptr def_mem(new mkldnn::memory(def_pd)); + mkldnn_format_tag_t format = mkl_mem_->GetDefaultFormat(); + mkldnn::memory::desc def_desc = mkl_mem_->GetDesc(format); + mkldnn_mem_ptr def_mem(new mkldnn::memory(def_desc, CpuEngine::Get()->get_engine())); mkl_mem_->ReorderTo(def_mem.get()); - CHECK(shandle.size >= def_pd.get_size()); - CheckAndAlloc(def_pd.get_size()); + CHECK(shandle.size >= def_desc.get_size()); + CheckAndAlloc(def_desc.get_size()); // TODO(zhengda) We need to avoid memory copy here. - memcpy(shandle.dptr, def_mem->get_data_handle(), def_pd.get_size()); + memcpy(shandle.dptr, def_mem->get_data_handle(), def_desc.get_size()); mkl_mem_ = nullptr; } -void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) { +void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::desc &md) { // If the memory already uses the specified layout, don't do anything. - if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(pd)) + if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(md)) return; - mkldnn::memory::primitive_desc _pd = pd; - mkldnn::memory::desc _desc = _pd.desc(); - mkldnn_memory_format_t def_format = GetDefaultFormat(_desc); + // If the memory is default, don't do anything. - if (def_format == _desc.data.format && IsDefault()) + if (!mxnet::IsMKLDNN(md) && IsDefault()) return; - // If the specified layout is default, we should use Reorder2Default. - if (def_format == _desc.data.format) { + if (!mxnet::IsMKLDNN(md)) { + // If the specified layout is default, we should use Reorder2Default. Reorder2Default(); return; } + auto engine = CpuEngine::Get()->get_engine(); + mkldnn::stream s(engine); - std::shared_ptr new_mem(new mkldnn::memory(pd)); + std::shared_ptr new_mem(new mkldnn::memory(md, engine)); std::shared_ptr old_mem; if (IsDefault()) { - mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(pd, def_format); - old_mem.reset(new mkldnn::memory(def_pd, shandle.dptr)); + mkldnn_format_tag_t def_format = GetDefaultFormat(md); + mkldnn::memory::desc def_desc = GetDesc(md, def_format); + old_mem.reset(new mkldnn::memory(def_desc, engine, shandle.dptr)); } else { old_mem = this->mkl_mem_->GetMem(); } - CHECK(old_mem->get_primitive_desc().desc().data.ndims == _desc.data.ndims); + CHECK(old_mem->get_desc().data.ndims == md.data.ndims); // This may be called in MKLDNN operators. We can't use MKLDNNStream here. - std::vector net; - net.push_back(mkldnn::reorder(*old_mem, *new_mem)); - mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + mkldnn::reorder(*old_mem, *new_mem).execute(s, *old_mem, *new_mem); - CHECK(shandle.size >= pd.get_size()); - CheckAndAlloc(pd.get_size()); + CHECK(shandle.size >= md.get_size()); + CheckAndAlloc(md.get_size()); // TODO(zhengda) We need to avoid memory copy here. - memcpy(shandle.dptr, new_mem->get_data_handle(), pd.get_size()); - mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr)); + memcpy(shandle.dptr, new_mem->get_data_handle(), md.get_size()); + mkl_mem_.reset(new MKLDNNMemory(md, shandle.dptr)); } void NDArray::Chunk::SetMKLMem(const mxnet::TShape &shape, int dtype) { @@ -484,50 +479,43 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape &shape, int dtype) { mkldnn::memory::dims dims; // These are shapes supprted by MKLDNN. - if (shape.ndim() >= 1 && shape.ndim() <= 5) { + if (shape.ndim() >= 1 && shape.ndim() <= 6) { dims.resize(shape.ndim()); for (size_t i = 0; i < dims.size(); i++) dims[i] = shape[i]; } else { LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions"; } - mkldnn::memory::format layout = mkldnn::memory::format::format_undef; + mkldnn::memory::format_tag layout = mkldnn::memory::format_tag::undef; switch (dims.size()) { - case 1: layout = mkldnn::memory::format::x; break; - case 2: layout = mkldnn::memory::format::nc; break; - case 3: layout = mkldnn::memory::format::ncw; break; - case 4: layout = mkldnn::memory::format::nchw; break; - // This isn't the right layout when the data has 5 dimensions in MXNet. - // MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have - // a corresponding format. - case 5: layout = mkldnn::memory::format::goihw; break; + case 1: layout = mkldnn::memory::format_tag::a; break; + case 2: layout = mkldnn::memory::format_tag::ab; break; + case 3: layout = mkldnn::memory::format_tag::abc; break; + case 4: layout = mkldnn::memory::format_tag::abcd; break; + case 5: layout = mkldnn::memory::format_tag::abcde; break; + case 6: layout = mkldnn::memory::format_tag::abcdef; break; + default: + LOG(FATAL) << "Not implemented dimension (" << dims.size() << ") for MKLDNN"; } mkldnn::memory::desc data_md{dims, get_mkldnn_type(dtype), layout}; - auto cpu_engine = CpuEngine::Get()->get_engine(); if (shandle.dptr == nullptr) { CHECK(delay_alloc); CheckAndAlloc(); } - mkldnn::memory::primitive_desc pd(data_md, cpu_engine); - CHECK(shandle.size >= pd.get_size()); - mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr)); + CHECK(shandle.size >= data_md.get_size()); + mkl_mem_.reset(new MKLDNNMemory(data_md, shandle.dptr)); } -const mkldnn::memory *NDArray::GetMKLDNNData( - const mkldnn::memory::primitive_desc &desc) const { +const mkldnn::memory *NDArray::GetMKLDNNData(const mkldnn::memory::desc &desc) const { if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; return nullptr; } const mkldnn::memory *mem = GetMKLDNNData(); - mkldnn::memory::primitive_desc _desc = desc; - mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc(); - mkldnn::memory::desc desc2 = _desc.desc(); + mkldnn::memory::desc desc1 = mem->get_desc(); // The MKL memory has the same format and shape as required, // or both use the default format, we can return the MKL memory. - if (mem->get_primitive_desc() == desc - || (desc1.data.format == GetDefaultFormat(desc1) - && desc2.data.format == GetDefaultFormat(desc2))) { + if (desc1 == desc || ((!mxnet::IsMKLDNN(desc1)) && (!mxnet::IsMKLDNN(desc)))) { return GetMKLDNNExact(mem, desc); } else { return nullptr; @@ -535,8 +523,8 @@ const mkldnn::memory *NDArray::GetMKLDNNData( } const mkldnn::memory *NDArray::GetMKLDNNDataReorder( - const mkldnn::memory::primitive_desc &new_pd) const { - if (new_pd.get_size() != shape().Size() * GetTypeSize(dtype_)) { + const mkldnn::memory::desc &new_desc) const { + if (new_desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; return nullptr; } @@ -545,39 +533,40 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder( const mkldnn::memory *mem = GetMKLDNNData(); // If the memory descriptor matches, it's easy. MKLDNNStream *stream = MKLDNNStream::Get(); - if (mem->get_primitive_desc() == new_pd) { - return GetMKLDNNExact(mem, new_pd); + if (mem->get_desc() == new_desc) { + return GetMKLDNNExact(mem, new_desc); } - mkldnn::memory::primitive_desc _pd = new_pd; - mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc(); - mkldnn::memory::desc desc2 = _pd.desc(); + mkldnn::memory::desc old_desc = mem->get_desc(); // Now we need to determine if we should reorder the memory. // If both use the default formats, we think we don't need to reorder. - if (desc1.data.format == GetDefaultFormat(desc1) && - desc2.data.format == GetDefaultFormat(desc2)) { - mkldnn_mem_ptr ret(new mkldnn::memory(new_pd, mem->get_data_handle())); + if ((!mxnet::IsMKLDNN(old_desc)) && (!mxnet::IsMKLDNN(new_desc))) { + mkldnn_mem_ptr ret(new mkldnn::memory(new_desc, + CpuEngine::Get()->get_engine(), mem->get_data_handle())); stream->RegisterMem(ret); return ret.get(); - } else if (same_shape(desc1, desc2)) { + } else if (same_shape(old_desc, new_desc)) { // If they have the same shape, we can reorder data directly. - mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(new_pd); - stream->RegisterPrim(mkldnn::reorder(*mem, *ret)); + mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(new_desc); + std::unordered_map args({{MKLDNN_ARG_FROM, *mem}, {MKLDNN_ARG_TO, *ret}}); + stream->RegisterPrimArgs(mkldnn::reorder(*mem, *ret), args); return ret; } else { // If they have different shapes, we need to reshape the array first. // Since this method will only be used inside an operator, we can call // MKLDNNDataReshape to reshape an array. - mxnet::TShape required_shape(desc2.data.ndims, -1); - for (int i = 0; i < desc2.data.ndims; i++) - required_shape[i] = desc2.data.dims[i]; + mxnet::TShape required_shape(new_desc.data.ndims, -1); + for (int i = 0; i < new_desc.data.ndims; i++) + required_shape[i] = new_desc.data.dims[i]; NDArray reshaped = MKLDNNDataReshape(required_shape); const mkldnn::memory *ret = reshaped.GetMKLDNNData(); - if (ret->get_primitive_desc() == new_pd) { - return GetMKLDNNExact(ret, new_pd); + if (ret->get_desc() == new_desc) { + return GetMKLDNNExact(ret, new_desc); } else { - mkldnn::memory *ret2 = TmpMemMgr::Get()->Alloc(new_pd); - stream->RegisterPrim(mkldnn::reorder(*ret, *ret2)); + mkldnn::memory *ret2 = TmpMemMgr::Get()->Alloc(new_desc); + std::unordered_map args({{MKLDNN_ARG_FROM, *ret}, + {MKLDNN_ARG_TO, *ret2}}); + stream->RegisterPrimArgs(mkldnn::reorder(*ret, *ret2), args); return ret2; } } @@ -588,18 +577,18 @@ NDArray NDArray::Reorder2Default() const { if (ptr_->mkl_mem_ == nullptr) return *this; - mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat(); - if (format == ptr_->mkl_mem_->GetFormat()) + if (!ptr_->mkl_mem_->IsMKLDNN()) return *this; // create new ndarray from mkldnn layout - mkldnn::memory::desc from_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc(); + mkldnn::memory::desc from_desc = ptr_->mkl_mem_->GetDesc(); mxnet::TShape tshape(from_desc.data.ndims, -1); for (int i = 0; i < from_desc.data.ndims; i++) tshape[i] = from_desc.data.dims[i]; NDArray ret(tshape, ctx(), false, dtype()); - mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format); - CHECK(ret.ptr_->shandle.size >= def_pd.get_size()); - mkldnn::memory def_mem(def_pd, ret.ptr_->shandle.dptr); + mkldnn_format_tag_t format = ptr_->mkl_mem_->GetDefaultFormat(); + mkldnn::memory::desc def_desc = ptr_->mkl_mem_->GetDesc(format); + CHECK(ret.ptr_->shandle.size >= def_desc.get_size()); + mkldnn::memory def_mem(def_desc, CpuEngine::Get()->get_engine(), ret.ptr_->shandle.dptr); ptr_->mkl_mem_->ReorderTo(&def_mem); // reshape as needed ret.shape_ = shape_; @@ -608,7 +597,7 @@ NDArray NDArray::Reorder2Default() const { return ret; } -void NDArray::Reorder2DefaultAsync() { +void NDArray::Reorder2DefaultAsync() const { std::vector const_vars; std::vector mutable_vars(1, this->var()); NDArray tmp = *this; @@ -620,21 +609,21 @@ void NDArray::Reorder2DefaultAsync() { FnProperty::kNormal, 0, "Reorder2Default"); } -void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc) { +void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc &desc) const { std::vector const_vars; std::vector mutable_vars(1, this->var()); NDArray tmp = *this; const auto version = this->version(); Engine::Get()->PushAsync( - [tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) { - // MXNet will try to reuse NDArray from memory planning, so we need to ensure - // the NDArray is still holding the original trunk data. - if (tmp.version() == version) { - tmp.ptr_->MKLDNNDataReorder(desc); - } - on_complete(); - }, - ctx(), const_vars, mutable_vars, FnProperty::kNormal, 0, "Reorder"); + [tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) { + // MXNet will try to reuse NDArray from memory planning, so we need to ensure + // the NDArray is still holding the original trunk data. + if (tmp.version() == version) { + tmp.ptr_->MKLDNNDataReorder(desc); + } + on_complete(); + }, ctx(), const_vars, mutable_vars, + FnProperty::kNormal, 0, "Reorder"); } const mkldnn::memory *NDArray::GetMKLDNNData() const { @@ -658,14 +647,12 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const { mkldnn::memory::dims dims(shape().ndim()); for (size_t i = 0; i < dims.size(); i++) dims[i] = shape()[i]; - mkldnn::memory::format cpp_format = static_cast( + mkldnn::memory::format_tag cpp_format = static_cast( GetDefaultFormat(shape().ndim())); mkldnn::memory::data_type cpp_type = get_mkldnn_type(dtype_); mkldnn::memory::desc data_md(dims, cpp_type, cpp_format); - mkldnn::memory::primitive_desc new_pd(data_md, - CpuEngine::Get()->get_engine()); - - std::shared_ptr ret(new mkldnn::memory(new_pd, off_addr)); + std::shared_ptr ret( + new mkldnn::memory(data_md, CpuEngine::Get()->get_engine(), off_addr)); MKLDNNStream::Get()->RegisterMem(ret); return ret.get(); } else { @@ -689,7 +676,7 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetRaw() == &mem) return; - CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_)) + CHECK(mem.get_desc().get_size() == shape().Size() * GetTypeSize(dtype_)) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; // If this array uses MKLDNN layout, we have to make sure it's not a view. // Otherwise, we'll have to change the layout inside the array. @@ -701,28 +688,25 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { MKLDNNCopy(mem, this_mem); } -mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) { +mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::desc &desc) { if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { - LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; + LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc. " + << "MKLDNN memory requests for " << desc.get_size() << " bytes, but got " + << shape().Size() * GetTypeSize(dtype_) << " bytes from NDArray"; return nullptr; } - - mkldnn::memory::primitive_desc _desc = desc; - mkldnn_memory_format_t required_format = _desc.desc().data.format; - mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc()); - // If the required format is a default format, we don't need to worry about the shape. - // If the shape isn't the same, it actually implicitly reshapes data. - if (required_format == def_format && !IsView()) { + bool isDefaultFormat = IsDefaultFormat(desc); + if (isDefaultFormat && !IsView()) { ptr_->SetMKLMem(shape_, dtype_); MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc); - } else if (required_format == def_format) { + } else if (isDefaultFormat) { ptr_->CheckAndAlloc(); CHECK(ptr_->shandle.dptr); // When this is a view and a user wants the default layout, we can simply // create a new mkldnn memory that points to the right memory. - std::shared_ptr mem(new mkldnn::memory( - desc, static_cast(ptr_->shandle.dptr) + byte_offset_)); + std::shared_ptr mem(new mkldnn::memory(desc, + CpuEngine::Get()->get_engine(), static_cast(ptr_->shandle.dptr) + byte_offset_)); MKLDNNStream::Get()->RegisterMem(mem); return mem.get(); } else if (IsView()) { @@ -736,7 +720,7 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & if (ptr_->mkl_mem_) CHECK(ptr_->mkl_mem_->GetDataHandle() == ptr_->shandle.dptr); - if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetPrimitiveDesc() == desc) { + if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetDesc() == desc) { MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc); } @@ -748,17 +732,14 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & return ptr_->mkl_mem_->GetRaw(); } -void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) { - const mkldnn::memory *mem = GetMKLDNNData(); - auto mem_desc = mem->get_primitive_desc().desc(); +void NDArray::UpdateMKLDNNMemDesc(const mkldnn::memory::desc &desc) { + auto new_desc = desc; auto this_dtype = get_mkldnn_type(dtype()); - mkldnn::memory::desc data_md( - mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims), - this_dtype, format); - mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); - ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr)); + new_desc.data.data_type = static_cast(this_dtype); + ptr_->mkl_mem_.reset(new MKLDNNMemory(new_desc, ptr_->shandle.dptr)); MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); } + #endif void NDArray::SetTBlob() const { @@ -1127,9 +1108,9 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext // by MKLDNN. auto from_mem = from.GetMKLDNNData(); auto to_mem = to.GetMKLDNNData(); - if (from_mem->get_primitive_desc() == to_mem->get_primitive_desc()) { - size_t size = std::min(from_mem->get_primitive_desc().get_size(), - to_mem->get_primitive_desc().get_size()); + if (from_mem->get_desc() == to_mem->get_desc()) { + size_t size = std::min(from_mem->get_desc().get_size(), + to_mem->get_desc().get_size()); memcpy(to_mem->get_data_handle(), from_mem->get_data_handle(), size); } else { const_cast(to).CopyFrom(*from_mem); diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 9e016bf884f2..fa62b0044a53 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -203,7 +203,7 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs, dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); } -#endif +#endif // MXNET_USE_MKLDNN == 1 if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); @@ -214,7 +214,7 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 if (!MKLDNNEnvSet()) *dispatch_mode = DispatchMode::kFComputeFallback; -#endif +#endif // MXNET_USE_MKLDNN == 1 return dispatched; } @@ -232,12 +232,12 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs, && param.dim > 0) wanted_mode = DispatchMode::kFComputeEx; else -#endif +#endif // MXNET_USE_MKLDNN == 1 wanted_mode = DispatchMode::kFCompute; #if MXNET_USE_MKLDNN == 1 if (!MKLDNNEnvSet()) wanted_mode = DispatchMode::kFComputeFallback; -#endif +#endif // MXNET_USE_MKLDNN == 1 return storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode); } @@ -249,12 +249,12 @@ bool SupportMKLDNNConcat(const std::vector &arrs) { // DO not support zero-size tensors. if (arr.shape().Size() == 0) return false; int ndim = arr.shape().ndim(); - const int mkldnn_ndims = arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims; + const int mkldnn_ndims = arr.GetMKLDNNData()->get_desc().data.ndims; if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false; } return true; } -#endif +#endif // MXNET_USE_MKLDNN == 1 static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs, const OpContext& op_ctx, const std::vector& inputs, @@ -274,7 +274,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs, MKLDNN_OPCHECK_RUN(ConcatCompute, attrs, op_ctx, inputs, req, outputs); } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) { FallBackCompute(ConcatCompute, attrs, op_ctx, inputs, req, outputs); -#endif +#endif // MXNET_USE_MKLDNN == 1 } else { LogUnimplementedOp(attrs, op_ctx, inputs, req, outputs); } @@ -294,7 +294,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs, } FallBackCompute(ConcatGradCompute, attrs, ctx, inputs, req, outputs); } -#endif +#endif // MXNET_USE_MKLDNN == 1 struct ConcatGrad { const char *op_name; @@ -306,7 +306,7 @@ struct ConcatGrad { for (size_t i = 0; i < n->inputs.size(); i++) { heads.push_back(n->inputs[i]); } -#endif +#endif // MXNET_USE_MKLDNN == 1 return MakeGradNode(op_name, n, heads, n->attrs.dict); } }; @@ -386,7 +386,7 @@ Example:: return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("TIsMKLDNN", true) -#endif +#endif // MXNET_USE_MKLDNN == 1 CONCAT_FORWARD_ATTRS .set_attr("FInferShape", ConcatShape) .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") @@ -402,13 +402,13 @@ NNVM_REGISTER_OP(_backward_Concat) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) -#endif +#endif // MXNET_USE_MKLDNN == 1 .set_attr("TIsBackward", true) .set_attr("FInferStorageType", BackwardConcatStorageType) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", ConcatGradComputeExCPU) -#endif +#endif // MXNET_USE_MKLDNN == 1 .set_attr("FCompute", ConcatGradCompute); // _rnn_param_concat is a custom concat op with specialized infer_shape, @@ -420,7 +420,7 @@ NNVM_REGISTER_OP(_rnn_param_concat) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) -#endif +#endif // MXNET_USE_MKLDNN == 1 CONCAT_FORWARD_ATTRS .set_attr("FInferShape", RNNParamConcatShape) .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 32ed93e4a463..e31073034594 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -60,7 +60,7 @@ static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs, const ConvolutionParam& params = nnvm::get(attrs.parsed); if (SupportMKLDNNConv(params, inputs[0])) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs); + MKLDNNRun(MKLDNNConvolutionForward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(ConvolutionCompute, attrs, ctx, inputs, req, outputs); return; } @@ -75,7 +75,7 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs, const ConvolutionParam& params = nnvm::get(attrs.parsed); if (SupportMKLDNNConv(params, inputs[0])) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs); + MKLDNNRun(MKLDNNConvolutionBackward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(ConvolutionGradCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index 9f461f4e9de3..b61f9ff37002 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -28,13 +28,74 @@ #include "../operator_common.h" #include "../../common/utils.h" #if MXNET_USE_MKLDNN == 1 -#include "./mkldnn/mkldnn_ops-inl.h" #include "./mkldnn/mkldnn_base-inl.h" -#endif +#include "./mkldnn/mkldnn_ops-inl.h" +#endif // MXNET_USE_MKLDNN namespace mxnet { namespace op { +#if MXNET_USE_MKLDNN == 1 +static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const DeconvolutionParam& params = nnvm::get(attrs.parsed); + if (SupportMKLDNNDeconv(params, inputs[0])) { + MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); + MKLDNNRun(MKLDNNDeconvolutionForward, attrs, ctx, inputs, req, outputs); + MKLDNN_OPCHECK_RUN(DeconvolutionCompute, attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(DeconvolutionCompute, attrs, ctx, inputs, req, outputs); +} + +static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const DeconvolutionParam& params = nnvm::get(attrs.parsed); + if (SupportMKLDNNDeconv(params, inputs[0])) { + MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); + MKLDNNRun(MKLDNNDeconvolutionBackward, attrs, ctx, inputs, req, outputs); + MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute, attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(DeconvolutionGradCompute, attrs, ctx, inputs, req, outputs); +} + +inline static bool DeconvStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + uint32_t in_expected = param.no_bias ? 2 : 3; + CHECK_EQ(in_attrs->size(), in_expected); + CHECK_EQ(out_attrs->size(), 1); + + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, + out_attrs); +} + +inline static bool BackwardDeconvStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + uint32_t in_expected = param.no_bias ? 3 : 4; + uint32_t out_expected = param.no_bias ? 2 : 3; + CHECK_EQ(in_attrs->size(), in_expected); + CHECK_EQ(out_attrs->size(), out_expected); + + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, + out_attrs); +} +#endif + static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) { @@ -284,70 +345,6 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, return true; } -#if MXNET_USE_MKLDNN == 1 -inline static bool DeconvStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const DeconvolutionParam& param = nnvm::get(attrs.parsed); - uint32_t in_expected = param.no_bias ? 2 : 3; - CHECK_EQ(in_attrs->size(), in_expected); - CHECK_EQ(out_attrs->size(), 1); - - return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, - out_attrs); -} - -inline static bool BackwardDeconvStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const DeconvolutionParam& param = nnvm::get(attrs.parsed); - uint32_t out_expected = param.no_bias ? 2 : 3; - CHECK_EQ(in_attrs->size(), param.no_bias ? 3U : 4U); - CHECK_EQ(out_attrs->size(), out_expected); - - return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, - out_attrs); -} - -static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const DeconvolutionParam& param = nnvm::get(attrs.parsed); - if (SupportMKLDNNDeconv(param, inputs[0])) { - MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNDeconvolutionForward(attrs, ctx, inputs, req, outputs); - MKLDNN_OPCHECK_RUN(DeconvolutionCompute, attrs, ctx, inputs, req, - outputs); - return; - } - FallBackCompute(DeconvolutionCompute, attrs, ctx, inputs, req, - outputs); -} - -static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const DeconvolutionParam& param = nnvm::get(attrs.parsed); - if (SupportMKLDNNDeconv(param, inputs[0])) { - MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNDeconvolutionBackward(attrs, ctx, inputs, req, outputs); - MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute, attrs, ctx, inputs, req, - outputs); - return; - } - FallBackCompute(DeconvolutionGradCompute, attrs, ctx, inputs, req, - outputs); -} -#endif - static void DeconvolutionParamParser(nnvm::NodeAttrs* attrs) { using namespace mshadow; DeconvolutionParam param_; @@ -430,18 +427,16 @@ NNVM_REGISTER_OP(Deconvolution) }) .set_attr("FInferShape", DeconvolutionShape) .set_attr("FInferType", DeconvolutionType) -#if MXNET_USE_MKLDNN == 1 -.set_attr("FInferStorageType", DeconvStorageType) -#endif .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", DeconvolutionCompute) +.set_attr("FGradient", DeconvolutionGrad{"_backward_Deconvolution"}) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) +.set_attr("FInferStorageType", DeconvStorageType) .set_attr("FComputeEx", DeconvolutionComputeExCPU) #endif -.set_attr("FGradient", DeconvolutionGrad{"_backward_Deconvolution"}) .add_argument("data", "NDArray-or-Symbol", "Input tensor to the deconvolution operation.") .add_argument("weight", "NDArray-or-Symbol", "Weights representing the kernel.") .add_argument("bias", "NDArray-or-Symbol", "Bias added to the result after the deconvolution " @@ -454,15 +449,13 @@ NNVM_REGISTER_OP(_backward_Deconvolution) return params.no_bias ? 2 : 3; }) .set_attr("TIsBackward", true) -#if MXNET_USE_MKLDNN == 1 -.set_attr("FInferStorageType", BackwardDeconvStorageType) -#endif .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr_parser(DeconvolutionParamParser) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) +.set_attr("FInferStorageType", BackwardDeconvStorageType) .set_attr("FComputeEx", DeconvolutionGradComputeExCPU) #endif .set_attr("FCompute", DeconvolutionGradCompute); diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 6387dff96eb7..61239d33800c 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -39,7 +39,7 @@ #include "../random/sampler.h" #include "../tensor/elemwise_binary_broadcast_op.h" -#if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__) +#if (MSHADOW_USE_MKL == 1) && defined(_OPENMP) && !defined(__CUDACC__) #define MXNET_USE_MKL_DROPOUT 1 #endif diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index fba13f3ece88..1f6d9e313202 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -147,7 +147,10 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - if (SupportMKLDNNFC(inputs[0])) { + // TODO(rongzha1): disable due to flakiness in cpp test IMPERATIVE.FullyConnectedOp + // Will be fixed when we decide to enable the backward of FC. + bool mkldnn_fc_backward_enable = false; + if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); MKLDNNFCBackward(attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute, attrs, ctx, inputs, req, diff --git a/src/operator/nn/mkldnn/mkldnn_act-inl.h b/src/operator/nn/mkldnn/mkldnn_act-inl.h index 9c21b7f70f52..cf3e4f47d1ff 100644 --- a/src/operator/nn/mkldnn/mkldnn_act-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_act-inl.h @@ -20,7 +20,7 @@ /*! * Copyright (c) 2019 by Contributors * \file mkldnn_act-inl.h - * \brief MKLDNN(Quantized) Activation operator based on subgraph + * \brief MKLDNN Activation operator * /author Zhiyuan Huang */ @@ -33,8 +33,6 @@ #include #include "../activation-inl.h" #include "../../leaky_relu-inl.h" -#include "./mkldnn_ops-inl.h" -#include "./mkldnn_base-inl.h" namespace mxnet { namespace op { @@ -51,9 +49,10 @@ struct MKLDNNActParam { mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param); mkldnn::algorithm GetMKLDNNActAlgo(const LeakyReLUParam& param); + mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( const MKLDNNActParam& param, bool is_train, - const mkldnn::memory &input_mem, int dtype); + const mkldnn::memory &input_mem); class MKLDNNActForward { public: @@ -61,14 +60,13 @@ class MKLDNNActForward { MKLDNNActForward(const MKLDNNActParam& param, bool is_train, const NDArray &data, const mkldnn::memory &mem): fwd_pd( - GetActFwdDescImpl(param, is_train, mem, data.dtype())) {} - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output); - const mkldnn::eltwise_forward &GetFwd() const; + GetActFwdDescImpl(param, is_train, mem)) { + fwd_ = std::make_shared(fwd_pd); + } + const inline mkldnn::eltwise_forward &GetFwd() const; private: std::shared_ptr fwd_; - std::shared_ptr data_; - std::shared_ptr out_; }; typedef ParamOpSign MKLDNNActSignature; @@ -80,8 +78,28 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data); void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &in_data, const OpReqType &req, - const NDArray &out_data); + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); + +mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl( + const MKLDNNActParam ¶m, const mkldnn::memory &input_mem, + const mkldnn::memory &diff_dst_memory); + +class MKLDNNActBackward { + public: + const mkldnn::eltwise_backward::primitive_desc bwd_pd; + + explicit MKLDNNActBackward(const MKLDNNActParam ¶m, const NDArray &data, + const mkldnn::memory &mem, + const mkldnn::memory &diff_dst_memory): bwd_pd( + GetActBwdDescImpl(param, mem, diff_dst_memory)) { + bwd_prim_ = std::make_shared(bwd_pd); + } + const inline mkldnn::eltwise_backward &GetBwd() const; + + private: + std::shared_ptr bwd_prim_; +}; } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index f221ddf5e345..f3966e6566ce 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -23,6 +23,8 @@ * \author Da Zheng */ +#if MXNET_USE_MKLDNN == 1 + #include #include #include @@ -33,10 +35,7 @@ #include #include "../../operator_common.h" #include "mkldnn_act-inl.h" - -#if MXNET_USE_MKLDNN == 1 - -#include +#include "./mkldnn_base-inl.h" namespace mxnet { namespace op { @@ -107,11 +106,9 @@ mkldnn::algorithm GetMKLDNNActAlgo(const LeakyReLUParam& param) { mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( const MKLDNNActParam& param, bool is_train, - const mkldnn::memory &input_mem, int dtype) { - mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); - mkldnn::memory::desc data_md = data_mpd.desc(); - auto cpu_engine = data_mpd.get_engine(); - + const mkldnn::memory &input_mem) { + mkldnn::memory::desc data_md = input_mem.get_desc(); + auto cpu_engine = CpuEngine::Get()->get_engine(); auto alg = param.alg; auto prop = is_train ? mkldnn::prop_kind::forward_training : @@ -120,28 +117,7 @@ mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); } -void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { - if (this->data_ == nullptr) - this->data_ = std::make_shared(data.get_primitive_desc(), - data.get_data_handle()); - else - this->data_->set_data_handle(data.get_data_handle()); - - CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc()); - if (this->out_ == nullptr) - this->out_ = std::make_shared(fwd_pd.dst_primitive_desc(), - output.get_data_handle()); - else - this->out_->set_data_handle(output.get_data_handle()); - - if (this->fwd_ == nullptr) { - this->fwd_ = std::shared_ptr( - new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_), - *this->out_)); - } -} - -const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const { +const inline mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const { return *fwd_; } @@ -155,10 +131,9 @@ MKLDNNActForward &GetActForward(const MKLDNNActParam& param, #endif MKLDNNActSignature key(param); key.AddSign(ctx.is_train); - key.AddSign(param.alg); + key.AddSign(static_cast(param.alg)); key.AddSign(param.slope); key.AddSign(in_data); - auto it = fwds.find(key); if (it == fwds.end()) { MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem); @@ -182,9 +157,9 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, auto input_mem = in_buffer.GetMKLDNNData(); MKLDNNActForward &fwd = GetActForward(param_, ctx, in_buffer, *input_mem); - auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer); - fwd.SetNewMem(*input_mem, *out_mem_t.second); - stream->RegisterPrim(fwd.GetFwd()); + auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer); + stream->RegisterPrimArgs(fwd.GetFwd(), + {{ MKLDNN_ARG_SRC, *input_mem}, { MKLDNN_ARG_DST, *out_mem_t.second}}); CommitOutput(out_data, out_mem_t); stream->Submit(); } @@ -205,32 +180,21 @@ void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, auto input_mem = in_buffer.GetMKLDNNData(); MKLDNNActForward &fwd = GetActForward(param_, ctx, in_buffer, *input_mem); - auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer); - fwd.SetNewMem(*input_mem, *out_mem_t.second); - stream->RegisterPrim(fwd.GetFwd()); + auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer); + stream->RegisterPrimArgs(fwd.GetFwd(), + {{ MKLDNN_ARG_SRC, *input_mem}, { MKLDNN_ARG_DST, *out_mem_t.second}}); CommitOutput(out_data, out_mem_t); stream->Submit(); } -static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl( +mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl( const MKLDNNActParam ¶m, const mkldnn::memory &input_mem, - const mkldnn::memory &diff_dst_memory, int dtype) { - mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); - mkldnn::memory::desc data_md = data_mpd.desc(); - mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc(); - auto cpu_engine = data_mpd.get_engine(); + const mkldnn::memory &diff_dst_memory) { + mkldnn::memory::desc data_md = input_mem.get_desc(); + mkldnn::memory::desc diff_md = diff_dst_memory.get_desc(); + auto cpu_engine = CpuEngine::Get()->get_engine(); auto alg = param.alg; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, - alg, data_md, param.slope); - mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); - mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, param.slope); - mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, - fw_pdesc); - return bw_pdesc; - }); - LOG(FATAL) << "Unsupported data type for MKLDNN activation"; mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, alg, data_md, param.slope); mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); @@ -240,45 +204,9 @@ static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl( return bw_pdesc; } -class MKLDNNActBackward { - std::shared_ptr bwd; - std::shared_ptr data; - std::shared_ptr diff_dst_memory; - std::shared_ptr diff_src_memory; - - public: - const mkldnn::eltwise_backward::primitive_desc pd; - - explicit MKLDNNActBackward(const MKLDNNActParam ¶m, const NDArray &data, - const mkldnn::memory &mem, - const mkldnn::memory &diff_dst_memory) - : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {} - - void SetNewMem(const mkldnn::memory &data, - const mkldnn::memory &diff_dst_memory, - const mkldnn::memory &diff_src_memory) { - if (this->bwd != nullptr) { - this->data->set_data_handle(data.get_data_handle()); - this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle()); - this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle()); - } else { - this->data = std::shared_ptr(new mkldnn::memory( - data.get_primitive_desc(), data.get_data_handle())); - this->diff_dst_memory = std::shared_ptr( - new mkldnn::memory(diff_dst_memory.get_primitive_desc(), - diff_dst_memory.get_data_handle())); - this->diff_src_memory = std::shared_ptr( - new mkldnn::memory(diff_src_memory.get_primitive_desc(), - diff_src_memory.get_data_handle())); - this->bwd = std::shared_ptr( - new mkldnn::eltwise_backward( - this->pd, mkldnn::primitive::at(*this->data), - *this->diff_dst_memory, *this->diff_src_memory)); - } - } - - const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; } -}; +const inline mkldnn::eltwise_backward &MKLDNNActBackward::GetBwd() const { + return *bwd_prim_; +} static inline MKLDNNActBackward &GetActBackward(const MKLDNNActParam ¶m, const OpContext &ctx, @@ -327,15 +255,19 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx auto input_mem = in_buffer.GetMKLDNNData(); // We need to make sure the two inputs to eltwise_backward has the same memory // descriptor. Otherwise, the perf will suffer. - if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc()) - input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc()); + if (input_mem->get_desc() != diff_dst_memory->get_desc()) + input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc()); MKLDNNActBackward &bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); mkldnn_output_t diff_src_memory = - CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); - bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second); - stream->RegisterPrim(bwd.GetBwd()); + CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req); + mkldnn_args_map_t args = { + { MKLDNN_ARG_SRC, *input_mem }, + { MKLDNN_ARG_DIFF_DST, *diff_dst_memory }, + { MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second }, + }; + stream->RegisterPrimArgs(bwd.GetBwd(), args); CommitOutput(in_grad, diff_src_memory); stream->Submit(); } @@ -367,20 +299,23 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, auto input_mem = in_buffer.GetMKLDNNData(); // We need to make sure the two inputs to eltwise_backward has the same memory // descriptor. Otherwise, the perf will suffer. - if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc()) - input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc()); + if (input_mem->get_desc() != diff_dst_memory->get_desc()) + input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc()); MKLDNNActBackward &bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); mkldnn_output_t diff_src_memory = - CreateMKLDNNMem(output, bwd.pd.diff_src_primitive_desc(), req); - bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second); - stream->RegisterPrim(bwd.GetBwd()); + CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req); + mkldnn_args_map_t args = { + { MKLDNN_ARG_SRC, *input_mem }, + { MKLDNN_ARG_DIFF_DST, *diff_dst_memory }, + { MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second }, + }; + stream->RegisterPrimArgs(bwd.GetBwd(), args); CommitOutput(output, diff_src_memory); stream->Submit(); } } // namespace op } // namespace mxnet - #endif diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 961aa8b05a84..0f371d174e40 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -59,7 +59,7 @@ #include "mxnet/ndarray.h" #include "mxnet/op_attr_types.h" #include "mxnet/resource.h" -using namespace mkldnn; + namespace mxnet { // ===== CpuEngine ======================================= @@ -80,7 +80,7 @@ class CpuEngine { mkldnn::engine &get_engine() { return _cpu_engine; } protected: - CpuEngine() : _cpu_engine(mkldnn::engine::cpu, 0) {} + CpuEngine() : _cpu_engine(mkldnn::engine::kind::cpu, 0) {} ~CpuEngine() {} private: @@ -93,27 +93,22 @@ struct data_type_enum {}; template <> struct data_type_enum { - enum { type = mkldnn::memory::data_type::f32 }; + enum { type = static_cast(mkldnn::memory::data_type::f32) }; }; template <> struct data_type_enum { - enum { type = mkldnn::memory::data_type::s32 }; -}; - -template <> -struct data_type_enum { - enum { type = mkldnn::memory::data_type::s16 }; + enum { type = static_cast(mkldnn::memory::data_type::s32) }; }; template <> struct data_type_enum { - enum { type = mkldnn::memory::data_type::s8 }; + enum { type = static_cast(mkldnn::memory::data_type::s8) }; }; template <> struct data_type_enum { - enum { type = mkldnn::memory::data_type::u8 }; + enum { type = static_cast(mkldnn::memory::data_type::u8) }; }; static inline bool SupportMKLDNNArray(int dtype, const mxnet::TShape &shape) { @@ -206,7 +201,7 @@ static int GetTypeSize(int dtype) { static inline size_t GetArraySize(const NDArray &arr) { if (arr.IsMKLDNNData()) { - return arr.GetMKLDNNData()->get_primitive_desc().get_size(); + return arr.GetMKLDNNData()->get_desc().get_size(); } return arr.shape().Size() * GetTypeSize(arr.dtype()); } @@ -223,10 +218,25 @@ static inline mkldnn::memory::data_type get_mkldnn_type(int dtype) { return mkldnn::memory::data_type::u8; default: LOG(FATAL) << "unknown type for MKLDNN"; - return mkldnn::memory::data_type::data_undef; + return mkldnn::memory::data_type::undef; } } +template +static inline mkldnn::memory::data_type get_mkldnn_type() { + return static_cast(data_type_enum::type); +} + +static inline mkldnn_data_type_t get_mkldnn_type_t(int dtype) { + return static_cast(get_mkldnn_type(dtype)); +} + +template +static inline mkldnn_data_type_t get_mkldnn_type_t() { + return static_cast(data_type_enum::type); +} + + static inline int get_mxnet_type(mkldnn_data_type_t dtype) { auto mkldnn_dtype = static_cast(dtype); switch (mkldnn_dtype) { @@ -261,7 +271,21 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = -1 mkldnn::memory::dims dims(ndim); dtype = (dtype == -1) ? arr.dtype() : dtype; for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; - return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format::any}; + return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any}; +} + +inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype = -1) { + int ndim = arr.shape().ndim(); + mkldnn::memory::dims dims(ndim); + dtype = (dtype == -1) ? arr.dtype() : dtype; + for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; + auto format = mkldnn::memory::format_tag::any; + // for batch 256 alexnet benchmark test + if (dims.size() == 2) { + format = mkldnn::memory::format_tag::ab; + } + + return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format}; } inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, @@ -286,8 +310,17 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, static_cast(arr.shape()[C]), static_cast(arr.shape()[H]), static_cast(arr.shape()[W])}; } - return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format::any}; + return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any}; + } +} + +inline static bool CheckMKLDNNInputArrayIsView(const std::vector &inputs) { + for (const auto &in : inputs) { + if (in.IsView() && in.IsMKLDNNData()) { + return true; + } } + return false; } typedef std::shared_ptr mkldnn_mem_ptr; @@ -352,19 +385,24 @@ class TmpMemMgr { this->est_size = 0; } - mkldnn::memory *Alloc(const mkldnn::memory::primitive_desc &pd); + mkldnn::memory *Alloc(const mkldnn::memory::desc &md); }; +typedef std::unordered_map mkldnn_args_map_t; class MKLDNNStream { - std::vector net; + std::vector > net_prim_args; // Here we hold all memory related to the operators in the stream. std::vector > mem_holder; + mkldnn::stream s; public: static MKLDNNStream *Get(); - void RegisterPrim(const mkldnn::primitive &prim) { - net.push_back(prim); + MKLDNNStream(): s(CpuEngine::Get()->get_engine()) {} + + void RegisterPrimArgs(const mkldnn::primitive &prim, + const mkldnn_args_map_t &args) { + net_prim_args.emplace_back(prim, args); } void RegisterMem(std::shared_ptr mem) { @@ -372,7 +410,7 @@ class MKLDNNStream { } bool HasOps() const { - return !net.empty(); + return !net_prim_args.empty(); } /* @@ -381,9 +419,11 @@ class MKLDNNStream { * might want to separate mkldnn execution and memory cleanup. */ void Submit(bool cleanup = true) { - if (!net.empty()) { - mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); - net.clear(); + if (!net_prim_args.empty()) { + for (auto &v : net_prim_args) { + v.first.execute(s, v.second); + } + net_prim_args.clear(); } if (cleanup) Cleanup(); @@ -405,18 +445,18 @@ typedef std::pair mkldnn_output_t; void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem); /* - * Here we want to get MKLDNN memory whose primitive desc is exactly the same as + * Here we want to get MKLDNN memory whose desc is exactly the same as * the given one. operator== can't guarantee that. == can return true even if * the formats are different. I need to double check its format. */ static inline mkldnn::memory *GetMKLDNNExact( - const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) { - mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc(); - if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) { + const mkldnn::memory *mem, const mkldnn::memory::desc &desc) { + mkldnn::memory::desc src_desc = mem->get_desc(); + if (desc == src_desc) { return const_cast(mem); } else { std::shared_ptr ret(new mkldnn::memory( - desc, mem->get_data_handle())); + desc, CpuEngine::Get()->get_engine(), mem->get_data_handle())); MKLDNNStream::Get()->RegisterMem(ret); return ret.get(); } @@ -434,10 +474,10 @@ static inline mkldnn::memory *GetMKLDNNExact( * the output back to the output NDArray. */ mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr, - const mkldnn::memory::primitive_desc &desc, + const mkldnn::memory::desc &desc, OpReqType req, const NDArray* in_arr = nullptr); mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr, - const mkldnn::memory::primitive_desc &desc, + const mkldnn::memory::desc &desc, OpReqType req); /* This function has to be used with one of the functions above. */ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res); @@ -466,13 +506,15 @@ static inline void CreateDefaultInputs(const std::vector &arrs, const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups); const mkldnn::memory *GetWeights(const NDArray &arr, - const mkldnn::memory::primitive_desc &target_pd, + const mkldnn::memory::desc &target_md, int num_groups); -mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc); -mkldnn_memory_format_t GetDefaultFormat(int num_dims); -mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd, - mkldnn_memory_format_t format); +bool IsDefaultFormat(const mkldnn::memory::desc &desc); +bool IsMKLDNN(const mkldnn::memory::desc &desc); + +mkldnn_format_tag_t GetDefaultFormat(const mkldnn::memory::desc &md); +mkldnn_format_tag_t GetDefaultFormat(int num_dims); +mkldnn::memory::desc GetDesc(const mkldnn::memory::desc &md, const mkldnn_format_tag_t &format); inline bool same_shape(const mxnet::TShape &shape, const mkldnn_dims_t dims, int ndims) { if (shape.ndim() != ndims) @@ -500,7 +542,7 @@ inline bool same_shape(const mxnet::TShape &shape, int dtype, } /* - * There is a large overhead of getting mkldnn::memory::primitive_desc from + * There is a large overhead of getting mkldnn::memory::desc from * mkldnn::memory. This class is created to cache the metadata of mkldnn memory * to provide a much more lightweight method to access them. */ @@ -510,16 +552,15 @@ class MKLDNNMemory { size_t size; // The number of bytes. public: - MKLDNNMemory(mkldnn::memory::primitive_desc pd, void *addr): desc(pd.desc()) { - mem.reset(new mkldnn::memory(pd, addr)); - size = pd.get_size(); + MKLDNNMemory(mkldnn::memory::desc md, void *addr): desc(md) { + mem.reset(new mkldnn::memory(md, CpuEngine::Get()->get_engine(), addr)); + size = desc.get_size(); } explicit MKLDNNMemory(std::shared_ptr mem): desc( - mem->get_primitive_desc().desc()) { + mem->get_desc()) { this->mem = mem; - mkldnn::memory::primitive_desc pd = mem->get_primitive_desc(); - size = pd.get_size(); + size = desc.get_size(); } void SetDataHandle(void *handle) { @@ -542,28 +583,29 @@ class MKLDNNMemory { return size; } - mkldnn::memory::primitive_desc GetPrimitiveDesc() const { - return mem->get_primitive_desc(); + mkldnn::memory::desc GetDesc() const { + return mem->get_desc(); } - mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn_memory_format_t format) const { - return mxnet::GetPrimitiveDesc(mem->get_primitive_desc(), format); + mkldnn::memory::desc GetDesc(mkldnn_format_tag_t format) const { + mkldnn::memory::dims dims(desc.data.dims, desc.data.dims + desc.data.ndims); + mkldnn::memory::data_type cpp_type = + static_cast(desc.data.data_type); + mkldnn::memory::desc data_md(dims, cpp_type, + static_cast(format)); + return data_md; } - mkldnn_memory_format_t GetDefaultFormat() const { + mkldnn_format_tag_t GetDefaultFormat() const { return mxnet::GetDefaultFormat(desc); } - mkldnn_memory_format_t GetFormat() const { - return desc.data.format; - } - bool IsMKLDNN() const { - return GetFormat() != GetDefaultFormat(); + return mxnet::IsMKLDNN(desc); } - bool SameFormat(mkldnn::memory::primitive_desc pd) const { - return mem->get_primitive_desc() == pd; + bool SameFormat(mkldnn::memory::desc md) const { + return mem->get_desc() == md; } bool SameFormat(const mxnet::TShape &shape, int dtype) const { @@ -571,9 +613,8 @@ class MKLDNNMemory { } void ReorderTo(mkldnn::memory *other) const { - std::vector net; - net.push_back(mkldnn::reorder(*mem, *other)); - mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + mkldnn::stream s(CpuEngine::Get()->get_engine()); + mkldnn::reorder(*mem, *other).execute(s, *mem, *other); } }; @@ -630,12 +671,19 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs, if (debug) check.CopyResult(outputs, indice); struct MKLDNNPostEltwiseParam { - mkldnn::algorithm alg = mkldnn::algorithm::algorithm_undef; + mkldnn::algorithm alg = mkldnn::algorithm::undef; float scale = 1.f; float alpha = 0.f; float beta = 1.f; }; +void MKLDNNRun(mxnet::FComputeEx fn, + const nnvm::NodeAttrs &attrs, + const mxnet::OpContext &ctx, + const std::vector &inputs_, + const std::vector &req, + const std::vector &outputs_); + } // namespace mxnet #endif #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index fca908fc8e39..1b147c69ba62 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -54,27 +54,32 @@ void *AlignMem(void *mem, size_t size, size_t alignment, size_t *space) { return reinterpret_cast(addr); } -mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) { +mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::desc &md) { // We need to include the size of the memory used for alignment. - this->est_size += pd.get_size() + alignment; - void *mem = AlignMem(this->curr_mem, pd.get_size(), alignment, &this->curr_size); + this->est_size += md.get_size() + alignment; + void *mem = AlignMem(this->curr_mem, md.get_size(), alignment, &this->curr_size); if (mem) { // The memory is allocated from the temporary memory space in the // operator. It'll only become invalid after we exit from the operator. - mkldnn_mem_ptr ret(new mkldnn::memory(pd, mem)); + mkldnn_mem_ptr ret(new mkldnn::memory(md, CpuEngine::Get()->get_engine(), mem)); MKLDNNStream::Get()->RegisterMem(ret); CHECK_EQ(mem, mem); - this->curr_size -= pd.get_size(); - this->curr_mem = static_cast(mem) + pd.get_size(); + this->curr_size -= md.get_size(); + this->curr_mem = static_cast(mem) + md.get_size(); return ret.get(); } else { - // If curr_mem has been initialized and we still reach here. It means - // the current allocated memory isn't enough. + // If curr_mem has been initialized and we still reach here, it means the current + // allocated memory isn't enough. But it doesn't matter for multiple invokes of a + // operator, as the TmpMemMgr could estimate the space at the first iteration and + // then re-requests abundant space from MXNet resource. MKL-DNN could allocate + // the space by itself. Thus, we just let it continue for estimating the maximum + // required space size. It will be allocated at next call. if (this->curr_mem && dmlc::GetEnv("MXNET_MKLDNN_DEBUG", false)) { - LOG(WARNING) << "Allocate " << pd.get_size() - << " bytes with malloc directly"; + LOG(WARNING) << "mkl-dnn debug message: The rest of the temporary space is not " + << "adequate for allocating " << md.get_size() << " bytes. Thus, mkl-dnn " + << "allocate the space by itself."; } - mkldnn_mem_ptr ret(new mkldnn::memory(pd)); + mkldnn_mem_ptr ret(new mkldnn::memory(md, CpuEngine::Get()->get_engine())); MKLDNNStream::Get()->RegisterMem(ret); return ret.get(); } @@ -82,85 +87,88 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) { void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) { MKLDNNStream *stream = MKLDNNStream::Get(); + mkldnn::memory::desc from_desc = mem.get_desc(); + mkldnn::memory::desc this_desc = this_mem->get_desc(); + mkldnn_format_tag_t from_def_format = GetDefaultFormat(from_desc); + mkldnn_format_tag_t this_def_format = GetDefaultFormat(this_desc); - mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc(); - mkldnn::memory::desc from_desc = from_pd.desc(); - mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc(); - mkldnn::memory::desc this_desc = this_pd.desc(); - mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc); - mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc); - // It's possible that the memory and the NDArray don't have the same shape. - if (!same_shape(this_desc, from_desc) - // If the source memory uses the default layout, we can reshape directly. - && from_def_format == from_desc.data.format) { + if (!same_shape(this_desc, from_desc) && IsDefaultFormat(from_desc)) { // In this case, we can simply create a new MKLDNN memory for the required // shape. mkldnn::memory::dims dims(this_desc.data.dims, this_desc.data.dims + this_desc.data.ndims); auto this_dtype = static_cast(this_desc.data.data_type); - auto this_format = static_cast(GetDefaultFormat(this_desc)); - mkldnn::memory::desc data_md(dims, this_dtype, this_format); - mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine()); - mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle())); + mkldnn::memory::desc data_md(dims, this_dtype, + static_cast(this_def_format)); + + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(data_md, mem.get_engine(), mem.get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem)); + std::unordered_map args({{MKLDNN_ARG_FROM, *tmp_mem}, + {MKLDNN_ARG_TO, *this_mem}}); + stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args); } else if (!same_shape(this_desc, from_desc)) { // In this case, the source memory stores data in a customized layout. We // need to reorganize the data in memory before we can reshape. - mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format); - mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd); - stream->RegisterPrim(mkldnn::reorder(mem, *def_mem)); + mkldnn::memory::desc def_desc = GetDesc(from_desc, from_def_format); + mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_desc); + std::unordered_map args({{MKLDNN_ARG_FROM, mem}, + {MKLDNN_ARG_TO, *def_mem}}); + stream->RegisterPrimArgs(mkldnn::reorder(mem, *def_mem), args); + // Now we can reshape it - mkldnn::memory::dims dims(this_desc.data.dims, - this_desc.data.dims + this_desc.data.ndims); - auto this_dtype = static_cast(this_desc.data.data_type); - auto this_format = static_cast(GetDefaultFormat(this_desc)); - mkldnn::memory::desc data_md(dims, this_dtype, this_format); - mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine()); - mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle())); + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(this_desc, + mem.get_engine(), def_mem->get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem)); - } else if (from_pd == this_pd) { + args = {{MKLDNN_ARG_FROM, *tmp_mem}, {MKLDNN_ARG_TO, *this_mem}}; + stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args); +} else if (this_desc == from_desc) { + std::unordered_map args({{MKLDNN_ARG_FROM, mem}, + {MKLDNN_ARG_TO, *this_mem}}); // If the layout is the same, we can just copy data. - stream->RegisterPrim(mkldnn::reorder(mem, *this_mem)); - } else { + stream->RegisterPrimArgs(mkldnn::reorder(mem, *this_mem), args); +} else { // If both are not using the default layouts. There isn't much we can do, // other than reorder data layout directly. - if (this_def_format != this_desc.data.format - && from_def_format != from_desc.data.format) { - stream->RegisterPrim(mkldnn::reorder(mem, *this_mem)); - } else if (this_def_format == this_desc.data.format) { + if (!IsDefaultFormat(this_desc) && !IsDefaultFormat(from_desc)) { + std::unordered_map args({{MKLDNN_ARG_FROM, mem}, + {MKLDNN_ARG_TO, *this_mem}}); + stream->RegisterPrimArgs(mkldnn::reorder(mem, *this_mem), args); + } else if (IsDefaultFormat(this_desc)) { // If the dest mem uses the default memory layout, we can simply use // the default format of the source memory to improve perf of reorder. - mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd, - from_def_format); - mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle())); + mkldnn::memory::desc desc = GetDesc(from_desc, from_def_format); + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(desc, + mem.get_engine(), this_mem->get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem)); + std::unordered_map args({{MKLDNN_ARG_FROM, mem}, + {MKLDNN_ARG_TO, *tmp_mem}}); + stream->RegisterPrimArgs(mkldnn::reorder(mem, *tmp_mem), args); } else { // If the src mem uses the default memory layout, we can use // the default format of the source memory to improve perf. - mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd, - this_def_format); - mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle())); + mkldnn::memory::desc desc = GetDesc(this_desc, this_def_format); + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(desc, + this_mem->get_engine(), mem.get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem)); + std::unordered_map args({{MKLDNN_ARG_FROM, *tmp_mem}, + {MKLDNN_ARG_TO, *this_mem}}); + stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args); } } } bool CanWriteTo(const NDArray &out_arr, const NDArray &in_arr, - const mkldnn::memory::primitive_desc &desc) { + const mkldnn::memory::desc &desc) { auto in_mem = in_arr.GetMKLDNNData(); bool add_same = in_mem->get_data_handle() == out_arr.GetMKLDNNData()->get_data_handle(); - bool pdesc_same = out_arr.GetMKLDNNData()->get_primitive_desc() == desc && - in_mem->get_primitive_desc() == desc; + bool pdesc_same = out_arr.GetMKLDNNData()->get_desc() == desc && + in_mem->get_desc() == desc; return add_same && pdesc_same; } mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr, - const mkldnn::memory::primitive_desc &desc, + const mkldnn::memory::desc &desc, OpReqType req, const NDArray* in_arr) { if (kAddTo == req) { @@ -188,7 +196,7 @@ mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr, } mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr, - const mkldnn::memory::primitive_desc &desc, + const mkldnn::memory::desc &desc, OpReqType req) { if (kAddTo == req) { auto tmp = TmpMemMgr::Get()->Alloc(desc); @@ -197,10 +205,8 @@ mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr, auto tmp = TmpMemMgr::Get()->Alloc(desc); return mkldnn_output_t(OutDataOp::CopyBack, tmp); } else { - auto _desc = desc; - auto def_format = GetDefaultFormat(_desc.desc()); mkldnn::memory *mem = nullptr; - if (def_format == _desc.desc().data.format) { + if (IsDefaultFormat(desc)) { mem = const_cast(out_arr).CreateMKLDNNData(desc); } if (mem == nullptr) { @@ -217,8 +223,8 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) { const_cast(arr).CopyFrom(*res.second); } else if (res.first == AddBack) { auto res_memory = res.second; - auto target_pd = arr.GetMKLDNNData()->get_primitive_desc(); - auto mem = arr.GetMKLDNNData(res.second->get_primitive_desc()); + auto target_pd = arr.GetMKLDNNData()->get_desc(); + auto mem = arr.GetMKLDNNData(res.second->get_desc()); if (mem == nullptr) { auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd); MKLDNNCopy(*res_memory, tmp_memory); @@ -232,12 +238,12 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) { const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) { const auto type = get_mkldnn_type(arr.dtype()); auto tz = mkldnn::memory::dims{0}; - auto format = mkldnn::memory::format::format_undef; + auto format_tag = mkldnn::memory::format_tag::undef; auto engine = CpuEngine::Get()->get_engine(); const int O = 0, I = 1, H = 2, W = 3; if (arr.shape().ndim() == 2) { tz = mkldnn::memory::dims{static_cast(arr.shape()[O]), static_cast(arr.shape()[I])}; - format = mkldnn::memory::format::oi; + format_tag = mkldnn::memory::format_tag::oi; } else if (arr.shape().ndim() == 3) { tz = num_groups > 1 ? mkldnn::memory::dims{num_groups, static_cast(arr.shape()[O] / num_groups), @@ -246,7 +252,8 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) { : mkldnn::memory::dims{static_cast(arr.shape()[O]), static_cast(arr.shape()[I]), static_cast(arr.shape()[H])}; - format = num_groups > 1 ? mkldnn::memory::format::goiw : mkldnn::memory::format::oiw; + format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goiw + : mkldnn::memory::format_tag::oiw; } else if (arr.shape().ndim() == 4) { tz = num_groups > 1 ? mkldnn::memory::dims{num_groups, static_cast(arr.shape()[O] / num_groups), @@ -256,168 +263,100 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) { : mkldnn::memory::dims{ static_cast(arr.shape()[O]), static_cast(arr.shape()[I]), static_cast(arr.shape()[H]), static_cast(arr.shape()[W])}; - format = num_groups > 1 ? mkldnn::memory::format::goihw : mkldnn::memory::format::oihw; + format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goihw + : mkldnn::memory::format_tag::oihw; } else { LOG(FATAL) << "The weight array has an unsupported number of dimensions"; } - const auto md = mkldnn::memory::desc{tz, type, format}; - const auto pd = mkldnn::memory::primitive_desc{md, engine}; - return arr.GetMKLDNNData(pd); + const auto md = mkldnn::memory::desc{tz, type, format_tag}; + return arr.GetMKLDNNData(md); } const mkldnn::memory *GetWeights(const NDArray &arr, - const mkldnn::memory::primitive_desc &target_pd, int num_groups) { - const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd); + const mkldnn::memory::desc &target_desc, int num_groups) { + const mkldnn::memory *mem = arr.GetMKLDNNData(target_desc); // If the weight array already uses the target layout, simply return it directly. if (mem) return mem; mem = GetWeights(arr, num_groups); - if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_pd); - if (mem->get_primitive_desc() == target_pd) return mem; + if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_desc); + if (mem->get_desc() == target_desc) return mem; - auto ret = TmpMemMgr::Get()->Alloc(target_pd); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*mem, *ret)); + auto ret = TmpMemMgr::Get()->Alloc(target_desc); + std::unordered_map args({{MKLDNN_ARG_FROM, *mem}, + {MKLDNN_ARG_TO, *ret}}); + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*mem, *ret), args); return ret; } -mkldnn_memory_format_t GetDefaultFormat(int num_dims) { + +// default: block and dims' stride increase monotonically +// mkldnn: 1.winograd 2.rnn packed 3. block and dims'stride is not increase monotonically +bool IsMKLDNN(const mkldnn::memory::desc &desc) { + bool rslt = true; + if (desc.data.format_kind == mkldnn_blocked) { + if (desc.data.format_desc.blocking.inner_nblks == 0) { + int i = 0; + for (i = 0; i < desc.data.ndims-1; i++) { + if (desc.data.format_desc.blocking.strides[i] + < desc.data.format_desc.blocking.strides[i + 1]) { + break; + } + } + if (i == desc.data.ndims-1) { + rslt = false; + } + } + } + return rslt; +} + +mkldnn_format_tag_t GetDefaultFormat(int num_dims) { switch (num_dims) { - case 1: return mkldnn_x; - case 2: return mkldnn_nc; - case 3: return mkldnn_ncw; - case 4: return mkldnn_nchw; - case 5: return mkldnn_goihw; + case 1: return mkldnn_a; + case 2: return mkldnn_ab; + case 3: return mkldnn_abc; + case 4: return mkldnn_abcd; + case 5: return mkldnn_abcde; + case 6: return mkldnn_abcdef; default: - LOG(FATAL) << "Unsupported MKLDNN dimensions: " << num_dims; - return mkldnn_format_undef; + LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for MKLDNN"; + return mkldnn_format_tag_undef; } } -mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { - if (desc.data.ndims == 1) { - return desc.data.format; - } else if (desc.data.ndims == 2) { - if (desc.data.format == mkldnn_io) - return mkldnn_oi; - else - return desc.data.format; - } else if (desc.data.ndims == 3) { - switch (desc.data.format) { - case mkldnn_ncw: - case mkldnn_nwc: - case mkldnn_nCw8c: - case mkldnn_nCw16c: - return mkldnn_ncw; - case mkldnn_oiw: - case mkldnn_wio: - case mkldnn_Owi8o: - case mkldnn_OIw8i8o: - case mkldnn_OIw8o8i: - case mkldnn_OIw16i16o: - case mkldnn_OIw16o16i: - case mkldnn_Oiw16o: - case mkldnn_Owi16o: - case mkldnn_OIw8i16o2i: - case mkldnn_OIw8o16i2o: - case mkldnn_IOw16o16i: - return mkldnn_oiw; - default: - LOG(FATAL) << "Unknown MKLDNN format for 3 dimensions: " << desc.data.format; - return mkldnn_format_undef; - } - } else if (desc.data.ndims == 4) { - switch (desc.data.format) { - case mkldnn_nchw: - case mkldnn_nhwc: - case mkldnn_chwn: - case mkldnn_nChw4c: - case mkldnn_nChw8c: - case mkldnn_nChw16c: - return mkldnn_nchw; - case mkldnn_oihw: - case mkldnn_ihwo: - case mkldnn_hwio: - case mkldnn_iohw: - case mkldnn_oIhw8i: - case mkldnn_oIhw16i: - case mkldnn_OIhw4i4o: - case mkldnn_OIhw8i8o: - case mkldnn_hwio_s8s8: - case mkldnn_OIhw16i16o: - case mkldnn_OIhw4i16o4i: - case mkldnn_OIhw4i16o4i_s8s8: - case mkldnn_OIhw8i16o2i: - case mkldnn_OIhw8o16i2o: - case mkldnn_OIhw8o8i: - case mkldnn_OIhw16o16i: - case mkldnn_IOhw16o16i: - case mkldnn_Oihw8o: - case mkldnn_Oihw16o: - case mkldnn_Ohwi8o: - case mkldnn_Ohwi16o: - case mkldnn_OhIw16o4i: - return mkldnn_oihw; - case mkldnn_goiw: - case mkldnn_gOwi8o: - case mkldnn_gOIw8o8i: - case mkldnn_gOIw8i8o: - case mkldnn_gOIw16i16o: - case mkldnn_gOIw16o16i: - case mkldnn_gOiw16o: - case mkldnn_gOwi16o: - case mkldnn_gOIw8i16o2i: - case mkldnn_gOIw8o16i2o: - case mkldnn_gIOw16o16i: - return mkldnn_goiw; - default: - LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format; - return mkldnn_format_undef; - } - } else if (desc.data.ndims == 5) { - switch (desc.data.format) { - case mkldnn_goihw: - case mkldnn_giohw: - case mkldnn_hwigo: - case mkldnn_hwigo_s8s8: - case mkldnn_gOIhw4i4o: - case mkldnn_gOIhw8i8o: - case mkldnn_gOIhw16i16o: - case mkldnn_gOIhw4i16o4i: - case mkldnn_gOIhw4i16o4i_s8s8: - case mkldnn_gOIhw8i16o2i: - case mkldnn_gOIhw8o16i2o: - case mkldnn_gOIhw8o8i: - case mkldnn_gOIhw4o4i: - case mkldnn_gOIhw16o16i: - case mkldnn_gIOhw16o16i: - case mkldnn_gOihw8o: - case mkldnn_Goihw8g: - case mkldnn_gOihw16o: - case mkldnn_Goihw16g: - case mkldnn_gOhwi8o: - case mkldnn_gOhwi16o: - case mkldnn_gOhIw16o4i: - case mkldnn_Goihw16g_s8s8: - return mkldnn_goihw; - default: - LOG(FATAL) << "Unknown MKLDNN format for 5 dimensions: " << desc.data.format; - return mkldnn_format_undef; +mkldnn_format_tag_t GetDefaultFormat(const mkldnn::memory::desc &desc) { + return GetDefaultFormat(desc.data.ndims); +} + +bool IsDefaultFormat(const mkldnn::memory::desc &desc) { + bool rslt = false; + if (desc.data.format_kind == mkldnn_blocked) { + if (desc.data.format_desc.blocking.inner_nblks == 0) { + int i = 0; + for (i = 0; i < desc.data.ndims-1; i++) { + if (desc.data.format_desc.blocking.strides[i] + < desc.data.format_desc.blocking.strides[i + 1]) { + break; + } + } + if (i == desc.data.ndims-1) { + rslt = true; + } } - } else { - LOG(FATAL) << "Unsupported dimensions: " << desc.data.ndims; - return mkldnn_format_undef; } + return rslt; } -mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd, - mkldnn_memory_format_t format) { - mkldnn::memory::dims dims(pd.desc().data.ndims); +mkldnn::memory::desc GetDesc(const mkldnn::memory::desc &desc, + const mkldnn_format_tag_t &format) { + mkldnn::memory::dims dims(desc.data.ndims); for (size_t i = 0; i < dims.size(); i++) - dims[i] = pd.desc().data.dims[i]; - mkldnn::memory::format cpp_format = static_cast(format); + dims[i] = desc.data.dims[i]; + mkldnn::memory::format_tag cpp_format = static_cast(format); mkldnn::memory::data_type cpp_type = static_cast( - pd.desc().data.data_type); + desc.data.data_type); mkldnn::memory::desc data_md(dims, cpp_type, cpp_format); - return mkldnn::memory::primitive_desc(data_md, pd.get_engine()); + return mkldnn::memory::desc(dims, cpp_type, cpp_format); } template @@ -513,10 +452,11 @@ static bool SimilarArray(const mxnet::NDArray &arr1, const mxnet::NDArray &arr2, std::atomic success(true); #pragma omp parallel for #ifdef _MSC_VER - for (int64_t i = 0; i < arr1.shape().Size(); i++) { + for (int64_t i = 0; i < arr1.shape().Size(); i++) #else - for (size_t i = 0; i < arr1.shape().Size(); i++) { + for (size_t i = 0; i < arr1.shape().Size(); i++) #endif + { if (std::abs(data1[i] - data2[i]) > atol + rtol * std::abs(data2[i])) success.store(false); } @@ -639,6 +579,33 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs, return dispatched; } +inline static const std::vector GetMKLDNNInputArray(const std::vector &inputs) { + std::vector ret; + ret.reserve(inputs.size()); + for (const auto &in : inputs) { + if (in.IsView() && in.IsMKLDNNData()) { + ret.push_back(in.Reorder2Default()); + } else { + ret.push_back(in); + } + } + return ret; +} + +void MKLDNNRun(mxnet::FComputeEx fn, + const nnvm::NodeAttrs &attrs, + const mxnet::OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + if (CheckMKLDNNInputArrayIsView(inputs)) { + const auto mkldnn_inputs = GetMKLDNNInputArray(inputs); + fn(attrs, ctx, mkldnn_inputs, req, outputs); + } else { + fn(attrs, ctx, inputs, req, outputs); + } +} + } // namespace mxnet #endif diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 61de08fdde23..26637c7c0b65 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -44,54 +44,44 @@ typedef mkldnn::batch_normalization_forward::desc t_bn_f_desc; typedef mkldnn::batch_normalization_backward::primitive_desc t_bn_b_pdesc; typedef mkldnn::batch_normalization_backward::desc t_bn_b_desc; -using mkldnn::use_global_stats; -using mkldnn::use_scale_shift; -using mkldnn::forward_training; -using mkldnn::forward_inference; - -inline static unsigned _GetFlags(const std::vector &in_data, +inline static mkldnn::normalization_flags _GetFlags(const std::vector &in_data, const std::vector &aux_states, const BatchNormParam ¶m, bool is_train_and_not_global_stats) { - unsigned flags = 0U; + mkldnn::normalization_flags flags = static_cast(0U); if (in_data.size() == 3U) { - flags |= use_scale_shift; + flags |= mkldnn::normalization_flags::use_scale_shift; } // aux_states[0]: inMean // aux_states[1]: inVariance if (aux_states.size() == 2U && !is_train_and_not_global_stats) { - flags |= use_global_stats; + flags |= mkldnn::normalization_flags::use_global_stats; } return flags; } -template inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem, bool is_train, - DType eps, - unsigned flags) { - auto data_mpd = data_mem.get_primitive_desc(); - auto data_md = data_mpd.desc(); - auto engine = CpuEngine::Get()->get_engine(); + float eps, + mkldnn::normalization_flags flags) { + auto data_md = data_mem.get_desc(); + auto engine = CpuEngine::Get()->get_engine(); if (is_train) { - t_bn_f_desc bnFwd_desc(forward_training, data_md, eps, flags); + t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_training, data_md, eps, flags); return t_bn_f_pdesc(bnFwd_desc, engine); } else { - t_bn_f_desc bnFwd_desc(forward_inference, data_md, eps, flags); + t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_inference, data_md, eps, flags); return t_bn_f_pdesc(bnFwd_desc, engine); } } -template inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem, const mkldnn::memory &diff_mem, - DType eps, - unsigned flags) { - auto data_mpd = data_mem.get_primitive_desc(); - auto data_md = data_mpd.desc(); - auto diff_mpd = diff_mem.get_primitive_desc(); - auto diff_md = diff_mpd.desc(); + float eps, + mkldnn::normalization_flags flags) { + auto data_md = data_mem.get_desc(); + auto diff_md = diff_mem.get_desc(); auto engine = CpuEngine::Get()->get_engine(); t_bn_b_desc bnBwd_desc(mkldnn::prop_kind::backward, diff_md, data_md, eps, flags); @@ -101,18 +91,15 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem, typedef ParamOpSign MKLDNNBNSignature; class MKLDNNBNForward { - std::shared_ptr data_m; std::shared_ptr weight_m; - std::shared_ptr out_m; - std::shared_ptr mean_m; - std::shared_ptr var_m; std::shared_ptr fwd; bool is_train_and_not_global_stats; t_bn_f_pdesc pd; public: MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) { - weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc())); + weight_m.reset(new mkldnn::memory(pd.weights_desc(), CpuEngine::Get()->get_engine())); + fwd.reset(new mkldnn::batch_normalization_forward(pd)); this->is_train_and_not_global_stats = is_train_and_not_global_stats; } @@ -124,59 +111,6 @@ class MKLDNNBNForward { return pd; } - const mkldnn::memory &GetMean() const { - return *mean_m; - } - - const mkldnn::memory &GetVar() const { - return *var_m; - } - - void SetDataHandle(const mkldnn::memory *data, const mkldnn::memory *mean, - const mkldnn::memory *var, const mkldnn::memory *out) { - if (data_m) { - data_m->set_data_handle(data->get_data_handle()); - } else { - data_m.reset(new mkldnn::memory(data->get_primitive_desc(), - data->get_data_handle())); - } - if (out_m) { - out_m->set_data_handle(out->get_data_handle()); - } else { - out_m.reset(new mkldnn::memory(out->get_primitive_desc(), - out->get_data_handle())); - } - if (mean_m) { - mean_m->set_data_handle(mean->get_data_handle()); - } else { - mean_m.reset(new mkldnn::memory(mean->get_primitive_desc(), - mean->get_data_handle())); - } - if (var_m) { - var_m->set_data_handle(var->get_data_handle()); - } else { - var_m.reset(new mkldnn::memory(var->get_primitive_desc(), - var->get_data_handle())); - } - - if (fwd == nullptr) { - if (!is_train_and_not_global_stats) - fwd.reset(new mkldnn::batch_normalization_forward( - pd, *data_m, mkldnn::primitive::at(*mean_m), - mkldnn::primitive::at(*var_m), *weight_m, *out_m)); - else - fwd.reset(new mkldnn::batch_normalization_forward( - pd, mkldnn::primitive::at(*data_m), - mkldnn::primitive::at(*weight_m), *out_m, - *mean_m, *var_m)); - } - } - - void SetDataHandle(const NDArray &data, const NDArray &mean, - const NDArray &var, const mkldnn::memory &out) { - SetDataHandle(data.GetMKLDNNData(), mean.GetMKLDNNData(), var.GetMKLDNNData(), &out); - } - const mkldnn::batch_normalization_forward &GetFwd() const { return *fwd; } @@ -185,7 +119,7 @@ class MKLDNNBNForward { template static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, const OpContext &ctx, const mkldnn::memory *data_mem, - unsigned flags) { + mkldnn::normalization_flags flags) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else @@ -193,13 +127,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, #endif MKLDNNBNSignature key(param); key.AddSign(ctx.is_train); - key.AddSign(param.use_global_stats); key.AddSign(*data_mem); auto it = fwds.find(key); if (it == fwds.end()) { auto fwd_pd = _GetFwd(*data_mem, ctx.is_train, - (DType) param.eps, flags); + param.eps, flags); MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats); it = AddToCache(&fwds, key, fwd); } @@ -209,7 +142,7 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, template static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, const OpContext &ctx, const NDArray &in_data, - unsigned flags) { + mkldnn::normalization_flags flags) { return GetBNForward(param, ctx, in_data.GetMKLDNNData(), flags); } @@ -220,18 +153,20 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, const std::vector &out_data, const std::vector &aux_states) { TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); - unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats); + mkldnn::normalization_flags flags = _GetFlags(in_data, + aux_states, + param, + ctx.is_train && !param.use_global_stats); const NDArray &data = in_data[batchnorm::kData]; - auto &fwd = GetBNForward(param, ctx, data, flags); - const NDArray &out = out_data[batchnorm::kOut]; + const NDArray &out = out_data[batchnorm::kOut]; // for output memory - auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc()); + auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_desc()); // mxnet will always use scale shift. // But if fix_gamma is true, then all scale elements will be set to 1.0f - if (flags & use_scale_shift) { + if (static_cast(flags) & static_cast(mkldnn::normalization_flags::use_scale_shift)) { const NDArray &gamma = in_data[batchnorm::kGamma]; const NDArray &beta = in_data[batchnorm::kBeta]; CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage); @@ -241,7 +176,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, DType* weight_buf = reinterpret_cast(weight_mem.get_data_handle()); nnvm::dim_t channels_ = data.shape()[1]; - CHECK(weight_mem.get_primitive_desc().get_size() == channels_ * sizeof(DType) * 2); + CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(DType) * 2); DType* weight_ptr = gamma.data().dptr(); DType* bias_ptr = beta.data().dptr(); if (!param.fix_gamma) { @@ -249,17 +184,22 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_); } else if (IsBNWriting(req[batchnorm::kGamma])) { for (int i = 0; i < channels_; i++) { - weight_buf[i] = (DType)1.0f; - weight_ptr[i] = (DType)1.0f; + weight_buf[i] = static_cast(1.0f); + weight_ptr[i] = static_cast(1.0f); weight_buf[channels_ + i] = bias_ptr[i]; // bias } } else { for (int i = 0; i < channels_; i++) { - weight_buf[i] = (DType)1.0f; + weight_buf[i] = static_cast(1.0f); weight_buf[channels_ + i] = bias_ptr[i]; // bias } } + mkldnn_args_map_t net_args; + net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData(); + net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem; + net_args[MKLDNN_ARG_DST] = *out_mem; + if (!ctx.is_train || param.use_global_stats) { DType* omean = out_data[batchnorm::kMean].data().dptr(); DType* ovar = out_data[batchnorm::kVar].data().dptr(); @@ -270,26 +210,21 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, omean[i] = inmean[i]; ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps); } - - fwd.SetDataHandle(data, aux_states[batchnorm::kMovingMean], - aux_states[batchnorm::kMovingVar], - *out_mem); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + net_args[MKLDNN_ARG_MEAN] = *(aux_states[batchnorm::kMovingMean].GetMKLDNNData()); + net_args[MKLDNN_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetMKLDNNData()); + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); MKLDNNStream::Get()->Submit(); } else { // training const NDArray &outMean = out_data[batchnorm::kMean]; const NDArray &outVar = out_data[batchnorm::kVar]; - DType* omean = outMean.data().dptr(); - DType* ovar = outVar.data().dptr(); - - fwd.SetDataHandle(data, outMean, outVar, *out_mem); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + net_args[MKLDNN_ARG_MEAN] = *(outMean.GetMKLDNNData()); + net_args[MKLDNN_ARG_VARIANCE] = *(outVar.GetMKLDNNData()); + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); MKLDNNStream::Get()->Submit(); - DType* mean_mem_ptr = reinterpret_cast(fwd.GetMean().get_data_handle()); - DType* var_mem_ptr = reinterpret_cast(fwd.GetVar().get_data_handle()); + + DType* ovar = outVar.data().dptr(); for (int i = 0; i < channels_; i++) { - omean[i] = mean_mem_ptr[i]; - ovar[i] = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps); + ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps); } } } else { // no input gamma and beta @@ -299,11 +234,6 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, class MKLDNNBNBackward { std::shared_ptr bwd; - std::shared_ptr data_m; - std::shared_ptr diff_m; - std::shared_ptr gradi_m; - std::shared_ptr mean_m; - std::shared_ptr var_m; const std::shared_ptr weight_m; const std::shared_ptr gradw_m; @@ -311,41 +241,16 @@ class MKLDNNBNBackward { const t_bn_b_pdesc pd; explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd) - : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())), - gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())), - pd(_pd) {} + : weight_m(new mkldnn::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())), + gradw_m(new mkldnn::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())), + pd(_pd) { + bwd.reset(new mkldnn::batch_normalization_backward(pd)); + } const mkldnn::memory &GetWeight() const { return *weight_m; } const mkldnn::memory &GetGradw() const { return *gradw_m; } - void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff, - const NDArray &mean, const mkldnn::memory &var, - const mkldnn::memory &gradi) { - auto mean_ptr = mean.data().dptr_; - if (bwd == nullptr) { - data_m.reset(new mkldnn::memory(data.get_primitive_desc(), - data.get_data_handle())); - diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(), - diff.get_data_handle())); - gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(), - gradi.get_data_handle())); - mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr)); - var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(), - var.get_data_handle())); - bwd.reset(new mkldnn::batch_normalization_backward( - pd, *data_m, mkldnn::primitive::at(*mean_m), - mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m, - *gradw_m)); - } else { - data_m->set_data_handle(data.get_data_handle()); - diff_m->set_data_handle(diff.get_data_handle()); - gradi_m->set_data_handle(gradi.get_data_handle()); - mean_m->set_data_handle(mean_ptr); - var_m->set_data_handle(var.get_data_handle()); - } - } - const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; } }; @@ -353,7 +258,7 @@ template static MKLDNNBNBackward &GetBNBackward( const BatchNormParam ¶m, const OpContext &ctx, const NDArray &in_data, const mkldnn::memory &in_mem, const NDArray &diff_data, - const mkldnn::memory &diff_mem, unsigned flags) { + const mkldnn::memory &diff_mem, mkldnn::normalization_flags flags) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map bwds; #else @@ -385,7 +290,10 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, CHECK_EQ(in_data.size(), 3U); CHECK_EQ(out_data.size(), 3U); CHECK_EQ(in_grad.size(), 3U); - unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats); + mkldnn::normalization_flags flags = _GetFlags(in_data, + aux_states, + param, + ctx.is_train && !param.use_global_stats); const NDArray &data = in_data[batchnorm::kData]; const NDArray &diff = out_grad[batchnorm::kOut]; @@ -405,13 +313,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, // MKLDNN batchnorm should run on special layouts. If one of them isn't, we // should reorder them. if (data.IsDefaultData()) - data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc()); + data_mem = data.GetMKLDNNDataReorder(diff_mem->get_desc()); else if (diff.IsDefaultData()) - diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc()); + diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc()); auto &bwd = GetBNBackward(param, ctx, data, *data_mem, diff, *diff_mem, flags); - auto gradi_mem = const_cast(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc()); + auto gradi_mem = const_cast(gradIn).CreateMKLDNNData(data_mem->get_desc()); - if (flags & use_scale_shift) { + if (static_cast(flags) & static_cast(mkldnn::normalization_flags::use_scale_shift)) { const NDArray &gamma = in_data[batchnorm::kGamma]; const NDArray &beta = in_data[batchnorm::kBeta]; DType *weight_buf = reinterpret_cast(bwd.GetWeight().get_data_handle()); @@ -420,20 +328,27 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, if (!param.fix_gamma) weight_buf[i] = (gamma.data().dptr())[i]; // weight else - weight_buf[i] = (DType)1.0f; + weight_buf[i] = static_cast(1.0f); } for (int i = 0; i < channels_; i++) { weight_buf[channels_ + i] = (beta.data().dptr())[i]; // bias } + mkldnn_args_map_t net_args; + net_args[MKLDNN_ARG_SRC] = *data_mem; + net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem; + net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight(); + net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw(); + net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem; + // training but no input mean and variance if (ctx.is_train && !param.use_global_stats) { DType* moving_mean_ptr = reinterpret_cast(moving_mean.data().dptr()); DType* moving_var_ptr = reinterpret_cast(moving_var.data().dptr()); DType* out_mean_ptr = reinterpret_cast(out_mean.data().dptr()); DType* out_var_ptr = reinterpret_cast(out_var.data().dptr()); - mkldnn::memory var_mem(bwd.pd.variance_primitive_desc()); + mkldnn::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine()); DType *tmp_var_ptr = reinterpret_cast(var_mem.get_data_handle()); DType minus_mom = (1.0f - param.momentum); @@ -445,13 +360,14 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, moving_var_ptr[i] = moving_var_ptr[i] * param.momentum + variance * minus_mom; } - bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem); - MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); + net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData()); + net_args[MKLDNN_ARG_VARIANCE] = var_mem; + MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); MKLDNNStream::Get()->Submit(); } else { - bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean, - *moving_var.GetMKLDNNData(), *gradi_mem); - MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); + net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData()); + net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData()); + MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_concat-inl.h b/src/operator/nn/mkldnn/mkldnn_concat-inl.h index d3866cc3d23d..ff47ef35f98f 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_concat-inl.h @@ -20,7 +20,7 @@ /*! * \file mkldnn_concat-inl.h * \brief - * \author Wenting Jiang + * \author */ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_ @@ -40,25 +40,20 @@ class MKLDNNConcatFwd { public: mkldnn::concat::primitive_desc fwd_pd; - MKLDNNConcatFwd(int concat_dim, const std::vector &data_md) - : fwd_pd(concat_dim, data_md) { - data.resize(data_md.size()); + MKLDNNConcatFwd(int concat_dim, const std::vector &data_md) + : fwd_pd(concat_dim, data_md, CpuEngine::Get()->get_engine()) { + fwd_ = std::make_shared(fwd_pd); } - void SetNewMem(const std::vector &in_data, const mkldnn::memory &output); - const mkldnn::concat &GetFwd() const; private: - std::shared_ptr fwd; - std::vector> data; - std::vector data_mem; - std::shared_ptr out; + std::shared_ptr fwd_; }; static MKLDNNConcatFwd &GetConcatForward( int concat_dim, const std::vector &in_data, - const std::vector &data_md) { + const std::vector &data_md) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc index 7b266efc2a14..aa30ffc557a1 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat.cc +++ b/src/operator/nn/mkldnn/mkldnn_concat.cc @@ -20,7 +20,7 @@ /*! * \file mkldnn_concat.cc * \brief - * \author Wenting Jiang + * \author */ #if MXNET_USE_MKLDNN == 1 @@ -29,28 +29,7 @@ namespace mxnet { namespace op { -void MKLDNNConcatFwd::SetNewMem(const std::vector &in_data, - const mkldnn::memory &output) { - CHECK_EQ(in_data.size(), data.size()); - for (size_t i = 0; i < data.size(); i++) { - if (this->data[i] == nullptr) { - this->data[i] = std::shared_ptr( - new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle())); - this->data_mem.push_back(*this->data[i]); - } else { - this->data[i]->set_data_handle(in_data[i]->get_data_handle()); - } - } - if (this->out == nullptr) - this->out = std::shared_ptr( - new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out->set_data_handle(output.get_data_handle()); - - if (this->fwd == nullptr) fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out)); -} - -const mkldnn::concat &MKLDNNConcatFwd::GetFwd() const { return *fwd; } +const mkldnn::concat &MKLDNNConcatFwd::GetFwd() const { return *fwd_; } void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, @@ -58,24 +37,28 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]); const ConcatParam& param = nnvm::get(attrs.parsed); - int num_in_data = param.num_args; - int concat_dim = param.dim; - std::vector data_md; + const int num_in_data = param.num_args; + const int concat_dim = param.dim; + std::vector data_md; std::vector data_mem; data_md.reserve(num_in_data); data_mem.reserve(num_in_data); for (int i = 0; i < num_in_data; i++) { const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData(); - mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc(); - data_md.push_back(tmp_pd); + mkldnn::memory::desc tmp_md = tmp_mem->get_desc(); + data_md.push_back(tmp_md); data_mem.push_back(tmp_mem); } MKLDNNConcatFwd &fwd = GetConcatForward(concat_dim, in_data, data_md); mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut], - fwd.fwd_pd.dst_primitive_desc(), + fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]); - fwd.SetNewMem(data_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + std::unordered_map net_args; + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + for (int i = 0; i < num_in_data; i++) { + net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *data_mem[i]}); + } + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); CommitOutput(out_data[concat_enum::kOut], out_mem); MKLDNNStream::Get()->Submit(); } @@ -86,11 +69,9 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector& outputs) { TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]); const ConcatParam& param = nnvm::get(attrs.parsed); - int num_in_data = param.num_args; - int axis_ = param.dim; - auto engine = CpuEngine::Get()->get_engine(); - auto gz_mem = inputs[0].GetMKLDNNData(); - mkldnn::memory::primitive_desc gz_pd = gz_mem->get_primitive_desc(); + const int num_in_data = param.num_args; + const int axis = param.dim; + const auto gradz_mem = inputs[0].GetMKLDNNData(); /* init the offset */ mkldnn::memory::dims offsets(outputs[0].shape().ndim()); for (auto &v : offsets) { @@ -99,22 +80,25 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, for (int i = 0; i < num_in_data; i++) { mkldnn::memory::dims diff_src_tz(outputs[i].shape().begin(), outputs[i].shape().end()); - auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc(); - auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]); - // create view from gy to gxs[i] - std::shared_ptr view_pd; - view_pd.reset(new mkldnn::view::primitive_desc(gz_pd, diff_src_tz, offsets)); - // create reorder primitive from gy to gxs[i] - mkldnn::reorder::primitive_desc reorder_pd( - view_pd.get()->dst_primitive_desc(), diff_src_mpd); - offsets[axis_] += diff_src_tz[axis_]; - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder( - reorder_pd, *gz_mem, *gradi_mem_.second)); - CommitOutput(outputs[i], gradi_mem_); + auto diff_src_md = outputs[i].GetMKLDNNData()->get_desc(); + auto gradi_mem = CreateMKLDNNMem(outputs[i], diff_src_md, req[i]); + + auto from_md = gradz_mem->get_desc().submemory_desc(diff_src_tz, offsets); + auto from_mem = new mkldnn::memory(from_md, gradz_mem->get_engine(), + gradz_mem->get_data_handle()); + offsets[axis] += diff_src_tz[axis]; + + std::unordered_map net_args({ + {MKLDNN_ARG_FROM, *gradz_mem}, + {MKLDNN_ARG_TO, *gradi_mem.second} + }); + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*from_mem, *gradi_mem.second), net_args); + CommitOutput(outputs[i], gradi_mem); } + MKLDNNStream::Get()->Submit(); } } // namespace op } // namespace mxnet -#endif +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 880b9d19cd81..ac2d3169340e 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -79,45 +79,28 @@ struct MKLDNNConvFullParam { MKLDNNPostEltwiseParam postsum_act_param; }; -mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam ¶m, - const bool is_train, - const NDArray &data, - const NDArray &weights, - const NDArray *bias, - const NDArray &output); +std::shared_ptr GetConvFwdImpl( + const ConvolutionParam ¶m, const bool is_train, const NDArray &data, const NDArray &weight, + const NDArray *bias, const NDArray &output); class MKLDNNConvForward { public: - mkldnn::convolution_forward::primitive_desc fwd_pd; - MKLDNNConvForward(const MKLDNNConvFullParam ¶m, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output); - - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output); + const NDArray &weight, const NDArray *bias, const NDArray &output); - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { - this->data_->set_data_handle(data.get_data_handle()); - this->out_->set_data_handle(output.get_data_handle()); - } + const mkldnn::convolution_forward &GetFwd() const { return *fwd_; } - const mkldnn::convolution_forward &GetFwd() const { - return *fwd_; - } + const mkldnn::convolution_forward::primitive_desc &GetPd() const { return *pd_; } private: std::shared_ptr fwd_; - std::shared_ptr data_; - std::shared_ptr weight_; - std::shared_ptr bias_; - std::shared_ptr out_; + std::shared_ptr pd_; }; typedef ParamOpSign MKLDNNConvSignature; -MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, - const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, +MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weight, const NDArray *bias, const NDArray &output); void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, @@ -127,6 +110,36 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, const std::vector &req, const std::vector &out_data); +void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); + +class MKLDNNConvBackward { + public: + MKLDNNConvBackward(const MKLDNNConvFullParam ¶m, const NDArray &data, const NDArray &weight, + const NDArray *bias, const NDArray &output); + + const mkldnn::convolution_backward_data &GetBwdData() const { return *bwd_data_; } + + const mkldnn::convolution_backward_weights &GetBwdWeights() const { return *bwd_weight_; } + + const mkldnn::convolution_backward_data::primitive_desc &GetDataPd() const { + return *bwd_data_pd_; + } + + const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() const { + return *bwd_weight_pd_; + } + + private: + std::shared_ptr bwd_data_pd_; + std::shared_ptr bwd_weight_pd_; + std::shared_ptr bwd_data_; + std::shared_ptr bwd_weight_; +}; + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 9cab2dd0e2b3..ada42a22cc8c 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -21,8 +21,7 @@ * \file mkldnn_convolution.cc * \brief * \author Da Zheng -*/ - + */ #if MXNET_USE_MKLDNN == 1 @@ -45,8 +44,10 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { (input.shape().ndim() == 4)); } -mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam ¶m, - const bool is_train, const NDArray &data, +std::shared_ptr GetConvFwdImpl( + const MKLDNNConvFullParam ¶m, + const bool is_train, + const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { @@ -57,7 +58,7 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP auto bias_md = bias ? (param.mkldnn_param.quantized ? GetMemDesc(*bias, mshadow::kInt32) : GetMemDesc(*bias)) : mkldnn::memory::desc{ - {}, mkldnn::memory::data_type::data_undef, mkldnn::memory::format::any}; + {}, mkldnn::memory::data_type::undef, mkldnn::memory::format_tag::any}; auto bias_md_ptr = bias ? &bias_md : nullptr; mkldnn::memory::dims strides(param.conv_param.kernel.ndim()); @@ -98,19 +99,19 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP if (param.mkldnn_param.quantized && param.requantize_scales.size()) { int mask = (param.requantize_scales.size() > 1) ? 2 : 0; attr.set_output_scales(mask, param.requantize_scales); - attr.set_int_output_round_mode(round_nearest); } auto GetConvFwdPd = [¶m, &data, &weights, &output, &attr](const mkldnn::convolution_forward::desc &desc) { auto engine = CpuEngine::Get()->get_engine(); try { - auto conv_pd = mkldnn::convolution_forward::primitive_desc(desc, attr, engine); - while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || + auto conv_pd = + std::make_shared(desc, attr, engine); + while (conv_pd->dst_desc().get_size() != GetArraySize(output) || + conv_pd->src_desc().get_size() != GetArraySize(data) || (!param.mkldnn_param.quantized && - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights))) { + conv_pd->weights_desc().get_size() != GetArraySize(weights))) { // next_impl() will visit desc and engine, please make sure they are still alive here. - CHECK(conv_pd.next_impl()) << "No convolution implementation for this request."; + CHECK(conv_pd->next_impl()) << "No convolution implementation for this request."; } return conv_pd; } catch (mkldnn::error &e) { @@ -126,13 +127,12 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP if (param.conv_param.dilate.ndim() == 0 && bias_md_ptr == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, - weight_md, out_md, strides, padding, padding, - mkldnn::padding_kind::zero); + weight_md, out_md, strides, padding, padding); return GetConvFwdPd(desc); } else if (param.conv_param.dilate.ndim() == 0) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, *bias_md_ptr, out_md, strides, padding, - padding, mkldnn::padding_kind::zero); + padding); return GetConvFwdPd(desc); } else { mkldnn::memory::dims dilates(param.conv_param.kernel.ndim()); @@ -147,23 +147,22 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP } if (bias_md_ptr == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, - weight_md, out_md, strides, dilates, padding, padding, - mkldnn::padding_kind::zero); + weight_md, out_md, strides, dilates, padding, padding); return GetConvFwdPd(desc); } else { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, *bias_md_ptr, out_md, strides, dilates, - padding, padding, mkldnn::padding_kind::zero); + padding, padding); return GetConvFwdPd(desc); } } } -static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData( - const ConvolutionParam& param, const NDArray &data, const NDArray &weights, +static std::shared_ptr GetConvBwdData( + const ConvolutionParam ¶m, const NDArray &data, const NDArray &weight, const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); + auto weight_md = GetWeightDesc(weight, param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); mkldnn::memory::dims strides(param.kernel.ndim()); @@ -187,21 +186,29 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData( << ", supporting only 1 or 2."; } - // MKL-DNN introduced padded formats since 0.15 which require more memory - // for computation compared with the actual tensor size. Currently, MKL-DNN - // operators are still reusing those memory from memory planning and the - // memory size may smaller than what MKL-DNN kernels require. So here we need - // select suboptimal kernel for computation according to tensor sizes. - if (param.dilate.ndim() == 0) { - mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.diff_src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; + auto GetConvBwdDataPd = [&data, &weight, &output, + &fwd_pd](const mkldnn::convolution_backward_data::desc &desc) { + auto engine = CpuEngine::Get()->get_engine(); + try { + auto conv_pd = + std::make_shared(desc, engine, fwd_pd); + while (conv_pd->diff_dst_desc().get_size() != GetArraySize(output) || + conv_pd->diff_src_desc().get_size() != GetArraySize(data) || + conv_pd->weights_desc().get_size() != GetArraySize(weight)) { + // next_impl() will visit desc and engine, please make sure they are still alive here. + CHECK(conv_pd->next_impl()) << "No convolution backward implementation for this request."; + } + return conv_pd; + } catch (mkldnn::error &e) { + LOG(ERROR) << e.message; + throw; } - return conv_pd; + }; + + if (param.dilate.ndim() == 0) { + mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md, + weight_md, out_md, strides, padding, padding); + return GetConvBwdDataPd(desc); } else { mkldnn::memory::dims dilates(param.kernel.ndim()); if (param.dilate.ndim() == 1) { @@ -213,25 +220,18 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData( LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size " << param.dilate.ndim() << ", supporting only 1 or 2."; } - mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, dilates, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.diff_src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md, + weight_md, out_md, strides, dilates, padding, + padding); + return GetConvBwdDataPd(desc); } } -static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( - const ConvolutionParam& param, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &fwd_pd) { +static std::shared_ptr GetConvBwdWeights( + const ConvolutionParam ¶m, const NDArray &data, const NDArray &weight, const NDArray *bias, + const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); + auto weight_md = GetWeightDesc(weight, param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); mkldnn::memory::dims strides(param.kernel.ndim()); @@ -255,33 +255,35 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( << ", supporting only 1 or 2."; } - // MKL-DNN introduced padded formats since 0.15 which require more memory - // for computation compared with the actual tensor size. Currently, MKL-DNN - // operators are still reusing those memory from memory planning and the - // memory size may smaller than what MKL-DNN kernels require. So here we need - // select suboptimal kernel for computation according to tensor sizes. - if (param.dilate.ndim() == 0 && bias == nullptr) { - mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; + auto GetConvBwdWeightsPd = [&data, &weight, &output, + &fwd_pd](const mkldnn::convolution_backward_weights::desc &desc) { + auto engine = CpuEngine::Get()->get_engine(); + try { + auto conv_pd = std::make_shared( + desc, engine, fwd_pd); + while (conv_pd->diff_dst_desc().get_size() != GetArraySize(output) || + conv_pd->src_desc().get_size() != GetArraySize(data) || + conv_pd->diff_weights_desc().get_size() != GetArraySize(weight)) { + // next_impl() will visit desc and engine, please make sure they are still alive here. + CHECK(conv_pd->next_impl()) << "No convolution backward implementation for this request."; + } + return conv_pd; + } catch (mkldnn::error &e) { + LOG(ERROR) << e.message; + throw; } - return conv_pd; + }; + + if (param.dilate.ndim() == 0 && bias == nullptr) { + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, data_md, + weight_md, out_md, strides, padding, padding); + return GetConvBwdWeightsPd(desc); } else if (param.dilate.ndim() == 0) { auto bias_md = GetMemDesc(*bias); - mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, bias_md, out_md, strides, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, data_md, + weight_md, bias_md, out_md, strides, padding, + padding); + return GetConvBwdWeightsPd(desc); } else { mkldnn::memory::dims dilates(param.kernel.ndim()); if (param.dilate.ndim() == 1) { @@ -295,313 +297,154 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( } if (bias == nullptr) { mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, dilates, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + data_md, weight_md, out_md, strides, dilates, + padding, padding); + return GetConvBwdWeightsPd(desc); } else { auto bias_md = GetMemDesc(*bias); mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, bias_md, out_md, - strides, dilates, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + data_md, weight_md, bias_md, out_md, strides, + dilates, padding, padding); + return GetConvBwdWeightsPd(desc); } } } MKLDNNConvForward::MKLDNNConvForward(const MKLDNNConvFullParam ¶m, const bool is_train, - const NDArray &data, const NDArray &weights, + const NDArray &data, const NDArray &weight, const NDArray *bias, const NDArray &output) - : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) { - data_ = std::make_shared(fwd_pd.src_primitive_desc(), nullptr); - weight_ = std::make_shared(fwd_pd.weights_primitive_desc(), nullptr); - out_ = std::make_shared(fwd_pd.dst_primitive_desc(), nullptr); - if (bias) { - bias_ = std::make_shared(fwd_pd.bias_primitive_desc(), nullptr); - fwd_ = std::make_shared(fwd_pd, *this->data_, *this->weight_, - *this->bias_, *this->out_); - } else { - fwd_ = std::make_shared(fwd_pd, *this->data_, *this->weight_, - *this->out_); - } + : pd_(GetConvFwdImpl(param, is_train, data, weight, bias, output)) { + fwd_ = std::make_shared(GetPd()); } -void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output) { - data_->set_data_handle(data.get_data_handle()); - weight_->set_data_handle(weight.get_data_handle()); - out_->set_data_handle(output.get_data_handle()); - if (bias != nullptr) bias_->set_data_handle(bias->get_data_handle()); -} - -MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, - const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, +MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weight, const NDArray *bias, const NDArray &output) { + using conv_fwd_map = std::unordered_map; #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map fwds; + static thread_local conv_fwd_map fwds; #else - static MX_THREAD_LOCAL std::unordered_map fwds; + static MX_THREAD_LOCAL conv_fwd_map fwds; #endif - MKLDNNConvSignature key(param); + // TODO(zhennan): Hash conv_param for now, need to hash full param if we want to enable cache for + // fused conv + MKLDNNConvSignature key(param.conv_param); key.AddSign(is_train); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. + // Here we can sign the conv op with NDArray because conv primitive will decide the right layout + // for the, so we only need to get the shape and the data type of the arrays. key.AddSign(data); - key.AddSign(weights); + key.AddSign(weight); key.AddSign(output); - if (bias) - key.AddSign(*bias); + if (bias) key.AddSign(*bias); auto it = fwds.find(key); if (it == fwds.end()) { - MKLDNNConvFullParam full_param; - full_param.conv_param = param; - full_param.mkldnn_param.Init(std::unordered_map()); - MKLDNNConvForward fwd(full_param, is_train, data, weights, bias, output); + auto fwd = MKLDNNConvForward(param, is_train, data, weight, bias, output); it = AddToCache(&fwds, key, fwd); } return it->second; } -void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, - const OpContext &ctx, +void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, const OpContext &ctx, MKLDNNConvForward *fwd, const std::vector &in_data, const std::vector &req, const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); - auto data = in_data[conv::kData]; - if (data.IsView() && data.IsMKLDNNData()) - data = data.Reorder2Default(); - - auto weight = in_data[conv::kWeight]; - if (weight.IsView() && weight.IsMKLDNNData()) - weight = weight.Reorder2Default(); - + auto &data = in_data[conv::kData]; + auto &weight = in_data[conv::kWeight]; bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn; - auto data_mem = data.GetMKLDNNDataReorder( - fwd->fwd_pd.src_primitive_desc()); + auto data_mem = data.GetMKLDNNDataReorder(fwd->GetPd().src_desc()); const mkldnn::memory *weight_mem; if (ctx.is_train) { - // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it - // to the default format for now. + // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it to the default format + // for now. if (weight.IsMKLDNNData()) - // This asks the engine to change the layout of the weight array after - // it's used. + // This asks the engine to change the layout of the weight array after it's used. weight.Reorder2DefaultAsync(); - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), - param.conv_param.num_group); + weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group); } else { - // For inference, we want to reorder the weight array so we don't need to - // reorder data every time. + // For inference, we want to reorder the weight array so we don't need to reorder data every + // time. if (weight.IsDefaultData()) { - // We also need to modify the layout on the original weight array. The - // data conversion happens after the weight array is used. - weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), - param.conv_param.num_group); - + // We also need to modify the layout on the original weight array. The data conversion happens + // after the weight array is used. + weight.MKLDNNDataReorderAsync(fwd->GetPd().weights_desc()); + weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group); } else { weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); + CHECK(weight_mem->get_desc() == fwd->GetPd().weights_desc()); } } mkldnn_output_t out_mem; if (param.mkldnn_param.with_sum) { - out_mem = mkldnn_output_t( - OutDataOp::Noop, - const_cast(out_data[conv::kOut].GetMKLDNNData())); + out_mem = mkldnn_output_t(OutDataOp::Noop, + const_cast(out_data[conv::kOut].GetMKLDNNData())); } else { - out_mem = CreateMKLDNNMem(out_data[conv::kOut], - fwd->fwd_pd.dst_primitive_desc(), req[conv::kOut]); + out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd->GetPd().dst_desc(), req[conv::kOut]); } - const mkldnn::memory *bias_mem = nullptr; + mkldnn_args_map_t net_args; if (!no_bias) { - bias_mem = in_data[conv::kBias].GetMKLDNNData(); + const mkldnn::memory *bias_mem = in_data[conv::kBias].GetMKLDNNData(); + net_args.insert({MKLDNN_ARG_BIAS, *bias_mem}); } - fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd()); + net_args.insert({MKLDNN_ARG_SRC, *data_mem}); + net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem}); + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), net_args); CommitOutput(out_data[conv::kOut], out_mem); MKLDNNStream::Get()->Submit(); } -void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, +void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data) { MKLDNNConvFullParam param; param.conv_param = nnvm::get(attrs.parsed); param.mkldnn_param.Init(std::unordered_map()); - auto &fwd = GetConvFwd( - param.conv_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], - param.conv_param.no_bias ? nullptr : &in_data[conv::kBias], - out_data[conv::kOut]); + auto &fwd = + GetConvFwd(param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.conv_param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]); MKLDNNConvolutionForwardFullFeature(param, ctx, &fwd, in_data, req, out_data); } -class MKLDNNConvBackward { - std::shared_ptr bwd_data; - std::shared_ptr bwd_weight; - // conv::kData - std::shared_ptr out_grad; - std::shared_ptr in_grad; - std::shared_ptr weight; - // conv::kWeight - std::shared_ptr data; - std::shared_ptr output; - std::shared_ptr in_grad_weight; - std::shared_ptr in_grad_bias; - - public: - mkldnn::convolution_backward_data::primitive_desc bwdData_pd; - mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd; - - MKLDNNConvBackward( - const ConvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &fwd_pd): - bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)), - bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) { - } - - void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight, - const mkldnn::memory &in_grad) { - if (this->out_grad == nullptr) - this->out_grad = std::shared_ptr(new mkldnn::memory( - bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); - else - this->out_grad->set_data_handle(out_grad.get_data_handle()); - if (this->in_grad == nullptr) - this->in_grad = std::shared_ptr(new mkldnn::memory( - bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle())); - else - this->in_grad->set_data_handle(in_grad.get_data_handle()); - if (this->weight == nullptr) - this->weight = std::shared_ptr(new mkldnn::memory( - bwdData_pd.weights_primitive_desc(), weight.get_data_handle())); - else - this->weight->set_data_handle(weight.get_data_handle()); - if (this->bwd_data == nullptr) - this->bwd_data = std::shared_ptr( - new mkldnn::convolution_backward_data( - this->bwdData_pd, mkldnn::primitive::at(*this->out_grad), - mkldnn::primitive::at(*this->weight), *this->in_grad)); - } - - void SetWeightNewMem(const mkldnn::memory &data, - const mkldnn::memory &out_grad, - const mkldnn::memory &in_grad_weight) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( - bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); - if (this->output == nullptr) - this->output = std::shared_ptr(new mkldnn::memory( - bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); - else - this->output->set_data_handle(out_grad.get_data_handle()); - if (this->in_grad_weight == nullptr) - this->in_grad_weight = std::shared_ptr( - new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(), - in_grad_weight.get_data_handle())); - else - this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); - - if (this->bwd_weight == nullptr) - this->bwd_weight = std::shared_ptr( - new mkldnn::convolution_backward_weights( - this->bwdWeights_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->output), *this->in_grad_weight)); - } - - void SetWeightNewMem(const mkldnn::memory &data, - const mkldnn::memory &out_grad, - const mkldnn::memory &in_grad_weight, - const mkldnn::memory &in_grad_bias) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( - bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); - if (this->output == nullptr) - this->output = std::shared_ptr(new mkldnn::memory( - bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); - else - this->output->set_data_handle(out_grad.get_data_handle()); - if (this->in_grad_weight == nullptr) - this->in_grad_weight = std::shared_ptr( - new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(), - in_grad_weight.get_data_handle())); - else - this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); - - if (this->in_grad_bias == nullptr) - this->in_grad_bias = std::shared_ptr( - new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(), - in_grad_bias.get_data_handle())); - else - this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle()); - if (this->bwd_weight == nullptr) - this->bwd_weight = std::shared_ptr( - new mkldnn::convolution_backward_weights( - this->bwdWeights_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->output), *this->in_grad_weight, - *this->in_grad_bias)); - } - - const mkldnn::convolution_backward_data &GetBwdData() const { - return *bwd_data; - } - - const mkldnn::convolution_backward_weights &GetBwdWeights() const { - return *bwd_weight; - } -}; +MKLDNNConvBackward::MKLDNNConvBackward(const MKLDNNConvFullParam ¶m, const NDArray &data, + const NDArray &weight, const NDArray *bias, + const NDArray &output) { + const auto fwd_pd = GetConvFwdImpl(param, true, data, weight, bias, output); + bwd_data_pd_ = GetConvBwdData(param.conv_param, data, weight, output, *fwd_pd); + bwd_weight_pd_ = GetConvBwdWeights(param.conv_param, data, weight, bias, output, *fwd_pd); + bwd_data_ = std::make_shared(GetDataPd()); + bwd_weight_ = std::make_shared(GetWeightsPd()); +} -static inline MKLDNNConvBackward &GetConvBwd( - const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &fwd_pd) { +static inline MKLDNNConvBackward &GetConvBwd(const MKLDNNConvFullParam ¶m, const NDArray &data, + const NDArray &weight, const NDArray *bias, + const NDArray &output) { + using mkldnn_conv_bwd_map = std::unordered_map; #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map bwds; + static thread_local mkldnn_conv_bwd_map bwds; #else - static MX_THREAD_LOCAL std::unordered_map bwds; + static MX_THREAD_LOCAL mkldnn_conv_bwd_map bwds; #endif - const ConvolutionParam& param = nnvm::get(attrs.parsed); - MKLDNNConvSignature key(param); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. + // TODO(zhennan): Hash conv_param for now, need to hash full param if we want to enable cache for + // fused conv + MKLDNNConvSignature key(param.conv_param); + // Here we can sign the conv op with NDArray because conv primitive will decide the right layout + // for the, so we only need to get the shape and the data type of the arrays. key.AddSign(data); - key.AddSign(weights); + key.AddSign(weight); key.AddSign(output); - if (bias) - key.AddSign(*bias); - + if (bias) key.AddSign(*bias); auto it = bwds.find(key); if (it == bwds.end()) { - MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd); + auto bwd = MKLDNNConvBackward(param, data, weight, bias, output); it = AddToCache(&bwds, key, bwd); } return it->second; @@ -617,65 +460,49 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct full_param.conv_param = nnvm::get(attrs.parsed); full_param.mkldnn_param.Init(std::unordered_map()); - auto data = inputs[conv::kData + 1]; - if (data.IsView() && data.IsMKLDNNData()) - data = data.Reorder2Default(); + auto &data = inputs[conv::kData + 1]; + auto &weight = inputs[conv::kWeight + 1]; + const auto *bias = full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1]; + auto &out_grad = inputs[conv::kOut]; - auto weight = inputs[conv::kWeight + 1]; - if (weight.IsView() && weight.IsMKLDNNData()) - weight = weight.Reorder2Default(); - - const NDArray* bias = full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1]; - - auto out_grad = inputs[conv::kOut]; - if (out_grad.IsView() && out_grad.IsMKLDNNData()) - out_grad = out_grad.Reorder2Default(); - - mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl( - full_param, ctx.is_train, data, weight, bias, out_grad); const ConvolutionParam ¶m = full_param.conv_param; CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace"; - MKLDNNConvBackward &convBwd = GetConvBwd(attrs, data, - weight, bias, out_grad, fwd_pd); - auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - convBwd.bwdData_pd.diff_dst_primitive_desc()); + MKLDNNConvBackward &convBwd = GetConvBwd(full_param, data, weight, bias, out_grad); + auto out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetDataPd().diff_dst_desc()); if (req[conv::kData]) { - auto weight_mem = GetWeights(weight, - convBwd.bwdData_pd.weights_primitive_desc(), param.num_group); - auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], - convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]); - convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second); - MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData()); + auto weight_mem = GetWeights(weight, convBwd.GetDataPd().weights_desc(), param.num_group); + auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], convBwd.GetDataPd().diff_src_desc(), + req[conv::kData]); + MKLDNNStream::Get()->RegisterPrimArgs(convBwd.GetBwdData(), + {{MKLDNN_ARG_DIFF_DST, *out_grad_mem}, + {MKLDNN_ARG_WEIGHTS, *weight_mem}, + {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}}); CommitOutput(in_grad[conv::kData], in_grad_mem); } if (req[conv::kWeight] || req[conv::kBias]) { - MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, data, - weight, bias, out_grad, fwd_pd); - if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() != - convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc()) - out_grad_mem = out_grad.GetMKLDNNDataReorder( - convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = data.GetMKLDNNDataReorder( - convBwdWeight.bwdWeights_pd.src_primitive_desc()); + if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc()) + out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetWeightsPd().diff_dst_desc()); + auto data_mem = data.GetMKLDNNDataReorder(convBwd.GetWeightsPd().src_desc()); auto in_grad_weight = CreateMKLDNNWeightGrad( - in_grad[conv::kWeight], - convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(), - req[conv::kWeight]); - if (param.no_bias) { - convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem, - *in_grad_weight.second); - MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights()); - } else { - auto in_grad_bias = CreateMKLDNNMem( - in_grad[conv::kBias], - convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]); - convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem, - *in_grad_weight.second, *in_grad_bias.second); - MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights()); - CommitOutput(in_grad[conv::kBias], in_grad_bias); + in_grad[conv::kWeight], convBwd.GetWeightsPd().diff_weights_desc(), req[conv::kWeight]); + + mkldnn_args_map_t net_args = {{MKLDNN_ARG_DIFF_DST, *out_grad_mem}, + {MKLDNN_ARG_SRC, *data_mem}, + {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}}; + mkldnn_output_t in_grad_bias; + if (!param.no_bias) { + in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias], + convBwd.GetWeightsPd().diff_bias_desc(), + req[conv::kBias]); + net_args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second}); } + MKLDNNStream::Get()->RegisterPrimArgs(convBwd.GetBwdWeights(), net_args); CommitOutput(in_grad[conv::kWeight], in_grad_weight); + // CommitOutput Should run after RegisterPrimArgs for memory dependency + if (!param.no_bias) { + CommitOutput(in_grad[conv::kBias], in_grad_bias); + } } MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_copy.cc b/src/operator/nn/mkldnn/mkldnn_copy.cc index a7c280e1e713..cf8daa4b45df 100644 --- a/src/operator/nn/mkldnn/mkldnn_copy.cc +++ b/src/operator/nn/mkldnn/mkldnn_copy.cc @@ -18,12 +18,11 @@ */ /*! - * \file mkldnn_softmax.cc + * \file mkldnn_copy.cc * \brief - * \author Da Zheng + * \author */ -#include "../softmax-inl.h" #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" @@ -47,9 +46,9 @@ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, // We should try and force the input memory has the same format // as the input output. If not, we'll have to reorder memory. auto out_mem = out_data.GetMKLDNNData(); - in_mem = data.GetMKLDNNData(out_mem ->get_primitive_desc()); + in_mem = data.GetMKLDNNData(out_mem ->get_desc()); if (in_mem == nullptr) - in_mem = data.GetMKLDNNDataReorder(out_mem->get_primitive_desc()); + in_mem = data.GetMKLDNNDataReorder(out_mem->get_desc()); MKLDNNSum(*out_mem, *in_mem, *out_mem); } else { const_cast(out_data).CopyFrom(*in_mem); diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index 02a7368cce97..6537540fa209 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -20,34 +20,33 @@ /*! * \file mkldnn_deconvolution.cc * \brief - * \author Da Zheng, Rong Zhang (rong.a.zhang@intel.com) -*/ + */ #if MXNET_USE_MKLDNN == 1 #include "../deconvolution-inl.h" -#include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" +#include "./mkldnn_ops-inl.h" namespace mxnet { namespace op { -bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input) { - if (params.kernel.ndim() != 2) - return false; +bool SupportMKLDNNDeconv(const DeconvolutionParam ¶ms, + const NDArray &input) { + if (params.kernel.ndim() != 2) return false; return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4; } static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) { mkldnn::memory::dims dims(1); - // This is convolution on 4D data. The second dimension is the channel. + // This is deconvolution on 4D data. The second dimension is the channel. dims[0] = md.data.dims[1]; - return mkldnn::memory::desc(dims, - static_cast(md.data.data_type), - mkldnn::memory::format::any); + return mkldnn::memory::desc( + dims, static_cast(md.data.data_type), + mkldnn::memory::format_tag::any); } -static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_( +std::shared_ptr GetDeconvBwd_( const mkldnn::memory::desc &data_md, const mkldnn::memory::desc &weights_md, bool has_bias, const mkldnn::memory::desc &out_md, const mkldnn::engine &engine, const mkldnn::memory::dims &strides, @@ -58,34 +57,40 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_( // memory size may smaller than what MKL-DNN kernels require. So here we need // select suboptimal kernel for computation according to tensor sizes. if (!has_bias) { - mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training, - mkldnn::algorithm::convolution_direct, out_md, weights_md, data_md, strides, - dilates, padding, padding, mkldnn::padding_kind::zero); - auto deconv_pd = mkldnn::convolution_forward::primitive_desc(desc, engine); - while (deconv_pd.dst_primitive_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd.src_primitive_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd.weights_primitive_desc().get_size() != GetMemDescSize(weights_md)) { - CHECK(deconv_pd.next_impl()) << "No implementation"; + mkldnn::convolution_forward::desc desc( + mkldnn::prop_kind::forward_training, + mkldnn::algorithm::convolution_direct, out_md, weights_md, data_md, + strides, dilates, padding, padding); + auto deconv_pd = + std::make_shared(desc, + engine); + while (deconv_pd->dst_desc().get_size() != GetMemDescSize(data_md) || + deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || + deconv_pd->weights_desc().get_size() != GetMemDescSize(weights_md)) { + CHECK(deconv_pd->next_impl()) << "No implementation"; } return deconv_pd; } else { auto bias_md = GetBiasDesc(data_md); - mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training, + mkldnn::convolution_forward::desc desc( + mkldnn::prop_kind::forward_training, mkldnn::algorithm::convolution_direct, out_md, weights_md, bias_md, - data_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); - auto deconv_pd = mkldnn::convolution_forward::primitive_desc(desc, engine); - while (deconv_pd.dst_primitive_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd.src_primitive_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd.weights_primitive_desc().get_size() != GetMemDescSize(weights_md)) { - CHECK(deconv_pd.next_impl()) << "No implementation"; + data_md, strides, dilates, padding, padding); + auto deconv_pd = + std::make_shared(desc, + engine); + while (deconv_pd->dst_desc().get_size() != GetMemDescSize(data_md) || + deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || + deconv_pd->weights_desc().get_size() != GetMemDescSize(weights_md)) { + CHECK(deconv_pd->next_impl()) << "No implementation"; } return deconv_pd; } } -static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl( - const DeconvolutionParam& param, const NDArray &data, const NDArray &weights, - bool has_bias, const NDArray &output) { +std::shared_ptr +GetDeconvFwdImpl(const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, bool has_bias, const NDArray &output) { auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.num_group); auto out_md = GetMemDesc(output); @@ -103,27 +108,30 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl( dilate[0] = param.dilate[0] - 1; dilate[1] = param.dilate[1] - 1; auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, - strides, padding, dilate); - mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, - out_md, weight_md, data_md, strides, dilate, padding, padding, - mkldnn::padding_kind::zero); - auto deconv_pd = mkldnn::convolution_backward_data::primitive_desc(desc, engine, bwd_pd); + strides, padding, dilate); + mkldnn::convolution_backward_data::desc desc( + mkldnn::algorithm::convolution_direct, out_md, weight_md, data_md, + strides, dilate, padding, padding); + auto deconv_pd = + std::make_shared( + desc, engine, *bwd_pd); // MKL-DNN introduced padded formats since 0.15 which require more memory // for computation compared with the actual tensor size. Currently, MKL-DNN // operators are still reusing those memory from memory planning and the // memory size may smaller than what MKL-DNN kernels require. So here we need // select suboptimal kernel for computation according to tensor sizes. - while (deconv_pd.diff_dst_primitive_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd.diff_src_primitive_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd.weights_primitive_desc().get_size() != GetMemDescSize(weight_md)) { - CHECK(deconv_pd.next_impl()) << "No implementation"; + while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || + deconv_pd->diff_src_desc().get_size() != GetMemDescSize(out_md) || + deconv_pd->weights_desc().get_size() != GetMemDescSize(weight_md)) { + CHECK(deconv_pd->next_impl()) << "No implementation"; } return deconv_pd; } -static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl( - const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, bool has_bias, const NDArray &output) { +std::shared_ptr +GetDeconvBwdDataImpl(const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, bool has_bias, + const NDArray &output) { auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.num_group); auto out_md = GetMemDesc(output); @@ -140,11 +148,11 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl( mkldnn::memory::dims dilate{0, 0}; dilate[0] = param.dilate[0] - 1; dilate[1] = param.dilate[1] - 1; - return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, - strides, padding, dilate); + return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, strides, + padding, dilate); } -static mkldnn::convolution_backward_weights::primitive_desc +std::shared_ptr GetDeconvBwdWeightsImpl( const DeconvolutionParam ¶m, const NDArray &data, const NDArray &weights, bool has_bias, const NDArray &output, @@ -172,125 +180,64 @@ GetDeconvBwdWeightsImpl( // memory size may smaller than what MKL-DNN kernels require. So here we need // select suboptimal kernel for computation according to tensor sizes. if (!has_bias) { - mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - out_md, weight_md, data_md, strides, dilate, padding, padding, mkldnn::padding_kind::zero); - auto deconv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (deconv_pd.diff_dst_primitive_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd.src_primitive_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd.diff_weights_primitive_desc().get_size() != GetMemDescSize(weight_md)) { - CHECK(deconv_pd.next_impl()) << "No implementation"; + mkldnn::convolution_backward_weights::desc desc( + mkldnn::algorithm::convolution_direct, out_md, weight_md, data_md, + strides, dilate, padding, padding); + auto deconv_pd = + std::make_shared( + desc, engine, fwd_pd); + while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || + deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || + deconv_pd->diff_weights_desc().get_size() != + GetMemDescSize(weight_md)) { + CHECK(deconv_pd->next_impl()) << "No implementation"; } return deconv_pd; } else { auto bias_md = GetBiasDesc(data_md); - mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - out_md, weight_md, bias_md, data_md, strides, dilate, padding, padding, - mkldnn::padding_kind::zero); - auto deconv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (deconv_pd.diff_dst_primitive_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd.src_primitive_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd.diff_weights_primitive_desc().get_size() != GetMemDescSize(weight_md)) { - CHECK(deconv_pd.next_impl()) << "No implementation"; + mkldnn::convolution_backward_weights::desc desc( + mkldnn::algorithm::convolution_direct, out_md, weight_md, bias_md, + data_md, strides, dilate, padding, padding); + auto deconv_pd = + std::make_shared( + desc, engine, fwd_pd); + while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || + deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || + deconv_pd->diff_weights_desc().get_size() != + GetMemDescSize(weight_md)) { + CHECK(deconv_pd->next_impl()) << "No implementation"; } return deconv_pd; } } class MKLDNNDeconvForward { - std::shared_ptr fwd; - std::shared_ptr data; - std::shared_ptr weight; - std::shared_ptr bias; - std::shared_ptr out; - OutDataOp data_op; - public: - MKLDNNDeconvForward(const DeconvolutionParam& param, - const NDArray &data, - const NDArray &weights, - bool has_bias, + MKLDNNDeconvForward(const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, bool has_bias, const NDArray &output); - void SetDataHandle(const DeconvolutionParam& param, - const OpContext &ctx, - const NDArray &in_data, - const NDArray &weight, - const std::vector &req, - const std::vector &out_data); + const mkldnn::convolution_backward_data &GetFwd() const { return *fwd; } - void Execute(const std::vector &out_data); + const mkldnn::convolution_backward_data::primitive_desc &GetPd() const { + return *fwd_pd; + } private: - mkldnn::convolution_backward_data::primitive_desc fwd_pd; + std::shared_ptr fwd; + std::shared_ptr fwd_pd; }; // class MKLDNNDeconvForward -MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param, - const NDArray &data, - const NDArray &weights, - bool has_bias, - const NDArray &output) - :fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) { - this->data = std::shared_ptr(new mkldnn::memory( - fwd_pd.diff_dst_primitive_desc())); - this->weight = std::shared_ptr(new mkldnn::memory( - fwd_pd.weights_primitive_desc())); - this->out = std::shared_ptr(new mkldnn::memory( - fwd_pd.diff_src_primitive_desc())); - this->fwd = std::shared_ptr( - new mkldnn::convolution_backward_data(fwd_pd, - mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->weight), - *this->out)); +MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam ¶m, + const NDArray &data, + const NDArray &weights, bool has_bias, + const NDArray &output) + : fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) { + fwd = std::make_shared(GetPd()); } -void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param, - const OpContext &ctx, - const NDArray &in_data, - const NDArray &weight, - const std::vector &req, - const std::vector &out_data) { - auto data_mem = in_data.GetMKLDNNDataReorder( - fwd_pd.diff_dst_primitive_desc()); - const mkldnn::memory *weight_mem; - if (ctx.is_train) { - // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it - // to the default format for now. - if (weight.IsMKLDNNData()) - // This asks the engine to reorder data after the weight array is used. - const_cast(weight).Reorder2DefaultAsync(); - weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group); - } else { - // For inference, we want to reorder the weight array so we don't need to - // reorder data every time. - if (weight.IsDefaultData()) { - // We also need to modify the layout on the original weight array. - // Don't switch below sequence because naive engine will executes - // pushAsync synchronously. - const_cast(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc()); - weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group); - } else { - weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc()); - } - } - auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut], - fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]); - auto output = out_mem.second; - this->data->set_data_handle(data_mem->get_data_handle()); - this->weight->set_data_handle(weight_mem->get_data_handle()); - this->out->set_data_handle(output->get_data_handle()); - this->data_op = out_mem.first; -} - -void MKLDNNDeconvForward::Execute(const std::vector &out_data) { - MKLDNNStream::Get()->RegisterPrim(*fwd); - CommitOutput(out_data[deconv::kOut], mkldnn_output_t(this->data_op, this->out.get())); - MKLDNNStream::Get()->Submit(); -} - -static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param, - const OpContext &ctx, - const NDArray &bias, - const std::vector &out_data) { +static void MKLDNNDeconvFwdBiasPostProcess( + const DeconvolutionParam ¶m, const OpContext &ctx, const NDArray &bias, + const std::vector &out_data) { // add bias, broadcast bias to dim 1: channel if (!param.no_bias) { // MKLDNN only supports float right now. @@ -306,18 +253,19 @@ static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param, } } -static inline MKLDNNDeconvForward &GetDeconvFwd( - const nnvm::NodeAttrs& attrs, const NDArray &data, - const NDArray &weights, const NDArray *bias, - const NDArray &output) { +MKLDNNDeconvForward &GetDeconvFwd(const nnvm::NodeAttrs &attrs, + const NDArray &data, const NDArray &weights, + const NDArray *bias, const NDArray &output) { #if DMLC_CXX11_THREAD_LOCAL - static thread_local - std::unordered_map fwds; + static thread_local std::unordered_map + fwds; #else static MX_THREAD_LOCAL - std::unordered_map fwds; + std::unordered_map + fwds; #endif - const DeconvolutionParam& param = nnvm::get(attrs.parsed); + const DeconvolutionParam ¶m = nnvm::get(attrs.parsed); DeconvSignature key(param); // Here we can sign the conv op with NDArray because conv primitive will // decide the right layout for the, so we only need to get the shape and the @@ -325,82 +273,95 @@ static inline MKLDNNDeconvForward &GetDeconvFwd( key.AddSign(data); key.AddSign(weights); key.AddSign(output); - if (bias) - key.AddSign(*bias); + if (bias) key.AddSign(*bias); auto it = fwds.find(key); if (it == fwds.end()) { bool has_bias = (bias != nullptr); - MKLDNNDeconvForward fwd(param, data, weights, has_bias, output); + auto fwd = MKLDNNDeconvForward(param, data, weights, has_bias, output); it = AddToCache(&fwds, key, fwd); } return it->second; } -void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, +void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); - const DeconvolutionParam& param = nnvm::get(attrs.parsed); + const DeconvolutionParam ¶m = nnvm::get(attrs.parsed); - auto data = in_data[deconv::kData]; - if (data.IsView() && data.IsMKLDNNData()) - data = data.Reorder2Default(); + auto &data = in_data[deconv::kData]; + auto &weight = in_data[deconv::kWeight]; + const NDArray *bias = param.no_bias ? nullptr : &in_data[deconv::kBias]; - auto weight = in_data[deconv::kWeight]; - if (weight.IsView() && weight.IsMKLDNNData()) - weight = weight.Reorder2Default(); + MKLDNNDeconvForward &fwd = + GetDeconvFwd(attrs, data, weight, bias, out_data[deconv::kOut]); - const NDArray* bias = param.no_bias ? nullptr : &in_data[deconv::kBias]; + auto data_mem = data.GetMKLDNNDataReorder(fwd.GetPd().diff_dst_desc()); + const mkldnn::memory *weight_mem; + if (ctx.is_train) { + // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it + // to the default format for now. + if (weight.IsMKLDNNData()) + // This asks the engine to change the layout of the weight array after + // it's used. + weight.Reorder2DefaultAsync(); + weight_mem = + GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group); + } else { + // For inference, we want to reorder the weight array so we don't need to + // reorder data every time. + if (weight.IsDefaultData()) { + // We also need to modify the layout on the original weight array. The + // data conversion happens after the weight array is used. + weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc()); + weight_mem = + GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group); - MKLDNNDeconvForward &deconvFwd = GetDeconvFwd( - attrs, data, weight, bias, out_data[deconv::kOut]); + } else { + weight_mem = weight.GetMKLDNNData(); + CHECK(weight_mem->get_desc() == fwd.GetPd().weights_desc()); + } + } + mkldnn_output_t out_mem; + out_mem = CreateMKLDNNMem(out_data[deconv::kOut], fwd.GetPd().diff_src_desc(), + req[deconv::kOut]); - deconvFwd.SetDataHandle(param, ctx, data, weight, req, out_data); + mkldnn_args_map_t net_args; - deconvFwd.Execute(out_data); + net_args.insert({MKLDNN_ARG_DIFF_DST, *data_mem}); + net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem}); + net_args.insert({MKLDNN_ARG_DIFF_SRC, *out_mem.second}); + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); + CommitOutput(out_data[deconv::kOut], out_mem); + MKLDNNStream::Get()->Submit(); MKLDNNDeconvFwdBiasPostProcess(param, ctx, *bias, out_data); } class MKLDNNDeconvBackwardData { std::shared_ptr bwd; - std::shared_ptr data; - std::shared_ptr weight; - std::shared_ptr out; public: - const mkldnn::convolution_forward::primitive_desc pd; - + std::shared_ptr bwd_pd; MKLDNNDeconvBackwardData(const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray &output) - : pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) { - } - - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory &output) { - if (bwd == nullptr) { - this->data = std::shared_ptr( - new mkldnn::memory(pd.src_primitive_desc(), data.get_data_handle())); - this->weight = std::shared_ptr( - new mkldnn::memory(pd.weights_primitive_desc(), weight.get_data_handle())); - this->out = std::shared_ptr( - new mkldnn::memory(pd.dst_primitive_desc(), output.get_data_handle())); - bwd = std::shared_ptr( - new mkldnn::convolution_forward(pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->weight), - *this->out)); - } else { - this->data->set_data_handle(data.get_data_handle()); - this->weight->set_data_handle(weight.get_data_handle()); - this->out->set_data_handle(output.get_data_handle()); - } - } + const NDArray &weights, const NDArray &output); const mkldnn::convolution_forward &GetBwd() const { return *bwd; } + const mkldnn::convolution_forward::primitive_desc &GetDataPd() const { + return *bwd_pd; + } }; +MKLDNNDeconvBackwardData::MKLDNNDeconvBackwardData( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output) + : bwd_pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) { + bwd = std::make_shared(GetDataPd()); +} + typedef ParamOpSign MKLDNNDeconvSignature; static inline MKLDNNDeconvBackwardData &GetDeconvBwdData( @@ -425,7 +386,7 @@ static inline MKLDNNDeconvBackwardData &GetDeconvBwdData( auto it = bwds.find(key); if (it == bwds.end()) { - MKLDNNDeconvBackwardData bwd(param, data, weights, output); + auto bwd = MKLDNNDeconvBackwardData(param, data, weights, output); it = AddToCache(&bwds, key, bwd); } return it->second; @@ -433,44 +394,30 @@ static inline MKLDNNDeconvBackwardData &GetDeconvBwdData( class MKLDNNDeconvBackwardWeights { std::shared_ptr bwd; - std::shared_ptr data; - std::shared_ptr weight; - std::shared_ptr out; public: - const mkldnn::convolution_backward_weights::primitive_desc pd; - + std::shared_ptr + bwd_data_pd; MKLDNNDeconvBackwardWeights( const DeconvolutionParam ¶m, const NDArray &data, const NDArray &weights, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) - : pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output, - bwd_data_pd)) {} - - void SetNewMem( - const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory &output, - const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) { - if (bwd == nullptr) { - this->data = std::shared_ptr(new mkldnn::memory( - bwd_data_pd.src_primitive_desc(), data.get_data_handle())); - this->weight = std::shared_ptr(new mkldnn::memory( - bwd_data_pd.weights_primitive_desc(), weight.get_data_handle())); - this->out = std::shared_ptr(new mkldnn::memory( - bwd_data_pd.dst_primitive_desc(), output.get_data_handle())); - bwd = std::shared_ptr( - new mkldnn::convolution_backward_weights(pd, *this->data, - *this->weight, *this->out)); - } else { - this->data->set_data_handle(data.get_data_handle()); - this->weight->set_data_handle(weight.get_data_handle()); - this->out->set_data_handle(output.get_data_handle()); - } - } - + const mkldnn::convolution_forward::primitive_desc &bwd_data_pd); const mkldnn::convolution_backward_weights &GetBwd() const { return *bwd; } + const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() + const { + return *bwd_data_pd; + } }; +MKLDNNDeconvBackwardWeights::MKLDNNDeconvBackwardWeights( + const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray &output, + const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) + : bwd_data_pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output, + bwd_data_pd)) { + bwd = std::make_shared(GetWeightsPd()); +} + static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights( const DeconvolutionParam ¶m, const NDArray &data, const NDArray &weights, const NDArray &output, @@ -494,7 +441,8 @@ static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights( auto it = bwds.find(key); if (it == bwds.end()) { - MKLDNNDeconvBackwardWeights bwd(param, data, weights, output, bwd_data_pd); + auto bwd = + MKLDNNDeconvBackwardWeights(param, data, weights, output, bwd_data_pd); auto ins_ret = bwds.insert( std::pair(key, bwd)); @@ -513,47 +461,50 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs, const std::vector &in_grad = outputs; const DeconvolutionParam ¶m = nnvm::get(attrs.parsed); - auto data = inputs[deconv::kData + 1]; - if (data.IsView() && data.IsMKLDNNData()) - data = data.Reorder2Default(); - - auto weight = inputs[deconv::kWeight + 1]; - if (weight.IsView() && weight.IsMKLDNNData()) - weight = weight.Reorder2Default(); + auto &data = inputs[deconv::kData + 1]; + auto &weight = inputs[deconv::kWeight + 1]; + auto &out_grad = inputs[deconv::kOut]; CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace"; MKLDNNDeconvBackwardData &bwd_data = GetDeconvBwdData(param, data, weight, inputs[deconv::kOut]); - auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( - bwd_data.pd.src_primitive_desc()); + auto out_grad_mem = + out_grad.GetMKLDNNDataReorder(bwd_data.GetDataPd().src_desc()); if (req[deconv::kData]) { - auto weight_mem = - GetWeights(weight, bwd_data.pd.weights_primitive_desc(), param.num_group); + auto weight_mem = GetWeights(weight, bwd_data.GetDataPd().weights_desc(), + param.num_group); auto in_grad_mem = - CreateMKLDNNMem(in_grad[deconv::kData], - bwd_data.pd.dst_primitive_desc(), req[deconv::kData]); - bwd_data.SetNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second); - MKLDNNStream::Get()->RegisterPrim(bwd_data.GetBwd()); + CreateMKLDNNMem(in_grad[deconv::kData], bwd_data.GetDataPd().dst_desc(), + req[deconv::kData]); + mkldnn_args_map_t net_args = {{MKLDNN_ARG_SRC, *out_grad_mem}, + {MKLDNN_ARG_WEIGHTS, *weight_mem}, + {MKLDNN_ARG_DST, *in_grad_mem.second}}; + MKLDNNStream::Get()->RegisterPrimArgs(bwd_data.GetBwd(), net_args); CommitOutput(in_grad[deconv::kData], in_grad_mem); } if (req[deconv::kWeight]) { MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights( - param, data, weight, - inputs[deconv::kOut], bwd_data.pd); - if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc()) - out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( - bwd_weights.pd.src_primitive_desc()); - auto data_mem = data.GetMKLDNNDataReorder( - bwd_weights.pd.diff_dst_primitive_desc()); + param, data, weight, inputs[deconv::kOut], bwd_data.GetDataPd()); + if (bwd_data.GetDataPd().src_desc() != + bwd_weights.GetWeightsPd().src_desc()) + out_grad_mem = + out_grad.GetMKLDNNDataReorder(bwd_weights.GetWeightsPd().src_desc()); + auto data_mem = + data.GetMKLDNNDataReorder(bwd_weights.GetWeightsPd().diff_dst_desc()); auto in_grad_weight = CreateMKLDNNWeightGrad( - in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(), - req[deconv::kWeight]); - bwd_weights.SetNewMem(*out_grad_mem, *data_mem, *in_grad_weight.second, bwd_data.pd); - MKLDNNStream::Get()->RegisterPrim(bwd_weights.GetBwd()); + in_grad[deconv::kWeight], + bwd_weights.GetWeightsPd().diff_weights_desc(), req[deconv::kWeight]); + + mkldnn_args_map_t net_args = { + {MKLDNN_ARG_SRC, *out_grad_mem}, + {MKLDNN_ARG_DIFF_DST, *data_mem}, + {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}}; + MKLDNNStream::Get()->RegisterPrimArgs(bwd_weights.GetBwd(), net_args); CommitOutput(in_grad[deconv::kWeight], in_grad_weight); } MKLDNNStream::Get()->Submit(); + if (!param.no_bias) { typedef float DType; Stream *s = ctx.get_stream(); @@ -573,5 +524,4 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs, } // namespace op } // namespace mxnet - #endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h index fddaedc2459d..7d64cf5a92a7 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h @@ -50,7 +50,7 @@ struct MKLDNNFCParam: public dmlc::Parameter { DMLC_DECLARE_FIELD(enable_float_output).set_default(false) .describe("Whether to enable float32 output"); DMLC_DECLARE_FIELD(with_eltwise).set_default(false) - .describe("Whether there's a post elemwise after FullyConnected operator"); + .describe("Whether there's a post with_eltwise after FullyConnected operator"); DMLC_DECLARE_FIELD(min_calib_range) .set_default(dmlc::optional()) .describe("The minimum scalar value in the form of float32 obtained " @@ -85,10 +85,9 @@ class MKLDNNFullyConnectedForward { const NDArray &data, const NDArray &weight, const NDArray *bias, const mkldnn::memory::desc &out_md) - : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {} - - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output); + : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) { + fwd_ = std::make_shared(fwd_pd); + } const mkldnn::inner_product_forward &GetFwd() const { return *fwd_; @@ -96,10 +95,6 @@ class MKLDNNFullyConnectedForward { private: std::shared_ptr fwd_; - std::shared_ptr data_; - std::shared_ptr weight_; - std::shared_ptr bias_; - std::shared_ptr out_; }; typedef ParamOpSign MKLDNNFullyconSignature; diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index fbe37e227cd1..1e7f879c5322 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -37,7 +37,7 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( const NDArray &data, const NDArray &weight, const NDArray *bias, const mkldnn::memory::desc &out_md) { auto data_md = GetMemDesc(data); - auto weight_md = GetMemDesc(weight); + auto weight_md = GetFCWeightDesc(weight); auto engine = CpuEngine::Get()->get_engine(); auto propagation = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; @@ -67,7 +67,6 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( } attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); } } @@ -102,7 +101,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData( const NDArray &data, const NDArray &weight, const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetMemDesc(weight); + auto weight_md = GetFCWeightDesc(weight); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md); @@ -113,7 +112,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei const NDArray &data, const NDArray &weight, const NDArray *bias, const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetMemDesc(weight); + auto weight_md = GetFCWeightDesc(weight); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); if (bias) { @@ -130,51 +129,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei } } -void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data, - const mkldnn::memory &weight, - const mkldnn::memory *bias, - const mkldnn::memory &output) { - if (this->data_ == nullptr) - this->data_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data_->set_data_handle(data.get_data_handle()); - - if (this->weight_ == nullptr) - this->weight_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.weights_primitive_desc(), weight.get_data_handle())); - else - this->weight_->set_data_handle(weight.get_data_handle()); - - if (this->out_ == nullptr) - this->out_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out_->set_data_handle(output.get_data_handle()); - - if (bias != nullptr) { - if (this->bias_ == nullptr) - this->bias_ = std::shared_ptr(new mkldnn::memory( - fwd_pd.bias_primitive_desc(), bias->get_data_handle())); - else - this->bias_->set_data_handle(bias->get_data_handle()); - - if (this->fwd_ == nullptr) - this->fwd_ = std::shared_ptr( - new mkldnn::inner_product_forward( - fwd_pd, mkldnn::primitive::at(*this->data_), - mkldnn::primitive::at(*this->weight_), - mkldnn::primitive::at(*this->bias_), *this->out_)); - } else { - if (this->fwd_ == nullptr) { - this->fwd_ = std::shared_ptr( - new mkldnn::inner_product_forward( - fwd_pd, mkldnn::primitive::at(*this->data_), - mkldnn::primitive::at(*this->weight_), *this->out_)); - } - } -} - MKLDNNFullyConnectedForward &GetFCFwd( const FullyConnectedParam ¶m, const bool is_train, const NDArray &data, const NDArray &weight, @@ -223,13 +177,13 @@ void MKLDNNFCFlattenData(const FullyConnectedParam ¶m, mkldnn::memory::dims out_dims{static_cast(oshape.ProdShape(0, oshape.ndim()-1)), static_cast(oshape[ishape.ndim()-1])}; *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()), - mkldnn::memory::format::any); + mkldnn::memory::format_tag::any); } else { *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); mkldnn::memory::dims out_dims{static_cast(oshape[0]), static_cast(oshape.ProdShape(1, oshape.ndim()))}; *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()), - mkldnn::memory::format::any); + mkldnn::memory::format_tag::any); } } } @@ -244,35 +198,34 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param, NDArray weight = in_data[fullc::kWeight]; NDArray data = in_data[fullc::kData]; - auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc()); + auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc()); const mkldnn::memory *weight_mem; if (ctx.is_train) { if (weight.IsMKLDNNData()) { weight.Reorder2DefaultAsync(); } - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1); } else { - if (weight.IsDefaultData()) { - // We also need to modify the layout on the original weight array. - // Don't switch below sequence because naive engine will executes - // pushAsync synchronously. - weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); - } else { - weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); + weight_mem = weight.GetMKLDNNData(); + if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) { + weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc()); + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1); } } auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], - fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data); + fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data); + + mkldnn_args_map_t args = { + {MKLDNN_ARG_SRC, *data_mem}, + {MKLDNN_ARG_WEIGHTS, *weight_mem}, + {MKLDNN_ARG_DST, *out_mem.second}, + }; if (!full_param.default_param.no_bias) { auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder( - fwd->fwd_pd.bias_primitive_desc()); - fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - } else { - fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second); + fwd->fwd_pd.bias_desc()); + args[MKLDNN_ARG_BIAS] = *bias_mem; } - MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd()); + MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args); CommitOutput(out_data[fullc::kOut], out_mem); MKLDNNStream::Get()->Submit(); } @@ -339,13 +292,18 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData( data, weight, out_grad, fwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - ipBwdData_pd.diff_dst_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc()); + ipBwdData_pd.diff_dst_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc()); auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], - ipBwdData_pd.diff_src_primitive_desc(), + ipBwdData_pd.diff_src_desc(), req[fullc::kData]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data( - ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second)); + mkldnn_args_map_t args = { + {MKLDNN_ARG_DIFF_DST, *out_grad_mem}, + {MKLDNN_ARG_WEIGHTS, *weight_mem}, + {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second} + }; + + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args); CommitOutput(in_grad[fullc::kData], in_grad_mem); } if (req[fullc::kWeight]) { @@ -353,23 +311,26 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], out_grad, fwd_pd); auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - ipBwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc()); + ipBwdWeights_pd.diff_dst_desc()); + auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc()); auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight], - ipBwdWeights_pd.diff_weights_primitive_desc(), + ipBwdWeights_pd.diff_weights_desc(), req[fullc::kWeight]); + mkldnn_args_map_t args = { + {MKLDNN_ARG_DIFF_DST, *out_grad_mem}, + {MKLDNN_ARG_SRC, *data_mem}, + {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}, + }; + mkldnn_output_t in_grad_bias; - if (param.no_bias) { - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); - } else { + if (!param.no_bias) { in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias], - ipBwdWeights_pd.diff_bias_primitive_desc(), + ipBwdWeights_pd.diff_bias_desc(), req[fullc::kBias]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, - *in_grad_bias.second)); + args[MKLDNN_ARG_DIFF_BIAS] = *in_grad_bias.second; } + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args); CommitOutput(in_grad[fullc::kWeight], in_grad_weight); CommitOutput(in_grad[fullc::kBias], in_grad_bias); } diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h index 31b293a14c2c..ca7095fd3f02 100644 --- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h @@ -34,27 +34,27 @@ namespace mxnet { namespace op { -inline algorithm GetMKLDNNLRNAlgo(const LRNParam ¶m) { +inline mkldnn::algorithm GetMKLDNNLRNAlgo(const LRNParam ¶m) { // TODO(Patric): lrn_within_channel will cause core dump in MKLDNN backward // Need to confirm with MKLDNN team and fix later - return algorithm::lrn_across_channels; + return mkldnn::algorithm::lrn_across_channels; } inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc( - const LRNParam ¶m, const bool is_train, const memory::desc &src_md) { + const LRNParam ¶m, const bool is_train, const mkldnn::memory::desc &src_md) { mkldnn::engine &engine = CpuEngine::Get()->get_engine(); - const algorithm alg = GetMKLDNNLRNAlgo(param); + const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param); const float alpha = param.alpha; const float beta = param.beta; const int nsize = param.nsize; const float k = param.knorm; - auto kind = prop_kind::forward_training; + auto kind = mkldnn::prop_kind::forward_training; if (is_train) { - kind = prop_kind::forward_training; + kind = mkldnn::prop_kind::forward_training; } else { - kind = prop_kind::forward_scoring; + kind = mkldnn::prop_kind::forward_scoring; } - lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k); + mkldnn::lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k); return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine); } @@ -63,13 +63,13 @@ inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc( const mkldnn::memory::desc &diff_md, const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) { mkldnn::engine &engine = CpuEngine::Get()->get_engine(); - const algorithm alg = GetMKLDNNLRNAlgo(param); + const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param); const float alpha = param.alpha; const float beta = param.beta; const int nsize = param.nsize; const float k = param.knorm; - lrn_backward::desc lrnBwd_desc(alg, data_in_md, + mkldnn::lrn_backward::desc lrnBwd_desc(alg, data_in_md, diff_md, nsize, alpha, beta, k); return mkldnn::lrn_backward::primitive_desc(lrnBwd_desc, engine, lrnFwd_desc); @@ -83,33 +83,24 @@ class MKLDNNLRNFwd { public: MKLDNNLRNFwd(const LRNParam& param, bool is_train, - const NDArray &in_data): - is_train(is_train) { + const NDArray &in_data) { _Init(param, is_train, in_data); } ~MKLDNNLRNFwd() {} - void SetNewMem(const NDArray &data, - const NDArray &output, - const OpReqType req); - - void SetNewMem(const NDArray &in_data, - const mkldnn::memory *out_mem); - - void Execute(const NDArray &out_data); + void Execute(const OpContext &ctx, + const NDArray &in_data, + const OpReqType req, + const NDArray &out_data); mkldnn::lrn_forward &GetFwd(); - const mkldnn::memory *GetWs(); + mkldnn::lrn_forward::primitive_desc &GetFwdPd(); private: std::shared_ptr fwd; - std::shared_ptr in_mem; - std::shared_ptr out_mem; - std::shared_ptr ws_mem; - mkldnn_output_t output_mem_t; - bool is_train; + mkldnn::lrn_forward::primitive_desc fwd_pd; private: void _Init(const LRNParam ¶m, bool is_train, const NDArray &in_data); @@ -119,52 +110,37 @@ void MKLDNNLRNFwd::_Init(const LRNParam ¶m, bool is_train, const NDArray &in_data) { mkldnn::memory::desc in_data_md = - in_data.GetMKLDNNData()->get_primitive_desc().desc(); - mkldnn::lrn_forward::primitive_desc fwd_pd = + in_data.GetMKLDNNData()->get_desc(); + this->fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md); - this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData() - ->get_primitive_desc())); - this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc())); - if (is_train) { - // If it's training, we have to create a workspace memory. Otherwise, - // MKLDNN will have segmentation fault. - ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc())); - this->fwd = std::shared_ptr( - new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem), - *this->ws_mem, *this->out_mem)); - } else { - this->fwd = std::shared_ptr( - new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*(this->in_mem)), - *(this->out_mem))); - } -} - -void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data, - const NDArray &out_data, - const OpReqType req) { - const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData(); - output_mem_t = CreateMKLDNNMem(out_data, this->out_mem->get_primitive_desc(), req); - this->in_mem->set_data_handle(in_data_mem->get_data_handle()); - this->out_mem->set_data_handle(output_mem_t.second->get_data_handle()); + this->fwd = std::shared_ptr(new mkldnn::lrn_forward(this->fwd_pd)); } -void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data, - const mkldnn::memory *out_mem) { - const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData(); - this->in_mem->set_data_handle(in_data_mem->get_data_handle()); - this->out_mem->set_data_handle(out_mem->get_data_handle()); -} - -void MKLDNNLRNFwd::Execute(const NDArray &out_data) { - MKLDNNStream::Get()->RegisterPrim(*(this->fwd)); +void MKLDNNLRNFwd::Execute(const OpContext &ctx, + const NDArray &in_data, + const OpReqType req, + const NDArray &out_data) { + auto output_mem_t = CreateMKLDNNMem(out_data, (this->fwd_pd).dst_desc(), req); + + mkldnn_args_map_t args = { + { MKLDNN_ARG_SRC, *in_data.GetMKLDNNData()}, + { MKLDNN_ARG_DST, *output_mem_t.second }, + }; + std::shared_ptr workspace; + if (ctx.is_train) { + auto engine = CpuEngine::Get()->get_engine(); + workspace = std::make_shared((this->fwd_pd).workspace_desc(), engine); + args[MKLDNN_ARG_WORKSPACE] = *(workspace); + } + MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd), args); CommitOutput(out_data, output_mem_t); MKLDNNStream::Get()->Submit(); } mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; } +mkldnn::lrn_forward::primitive_desc &MKLDNNLRNFwd::GetFwdPd() { return this->fwd_pd; } -const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); } // End of LRN Class and its functions static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, @@ -180,10 +156,11 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, OpHash> lrn_fwds; #endif auto kind_ = - ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring; + ctx.is_train ? mkldnn::prop_kind::forward_training + : mkldnn::prop_kind::forward_scoring; MKLDNNLRNSignature key(param); - key.AddSign(kind_); + key.AddSign(static_cast(kind_)); key.AddSign(in_data); auto it = lrn_fwds.find(key); @@ -201,17 +178,12 @@ void MKLDNNLRNForward(const OpContext &ctx, const LRNParam ¶m, if (in_buffer.IsView() && in_buffer.IsMKLDNNData()) in_buffer = in_buffer.Reorder2Default(); MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer); - fwd.SetNewMem(in_buffer, out_data, req); - fwd.Execute(out_data); + fwd.Execute(ctx, in_buffer, req, out_data); } // LRN Backward Class class MKLDNNLRNBwd { std::shared_ptr bwd; - std::shared_ptr in_data_mem; - std::shared_ptr diff_dst_mem; - std::shared_ptr ws_mem; - std::shared_ptr diff_src_mem; public: const mkldnn::lrn_forward::primitive_desc fwd_pd; @@ -222,40 +194,26 @@ class MKLDNNLRNBwd { MKLDNNLRNBwd(const LRNParam ¶m, const mkldnn::memory::desc in_data_md, const mkldnn::memory::desc diff_md) : fwd_pd(GetLRNFwdDesc(param, true, in_data_md)), - bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {} - - void SetNewMem(const NDArray &in_data, const NDArray &out_grad, - const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) { - if (bwd == nullptr) { - this->in_data_mem.reset( - new mkldnn::memory(this->fwd_pd.src_primitive_desc(), - in_data.GetMKLDNNData()->get_data_handle())); - this->diff_dst_mem.reset( - new mkldnn::memory(this->fwd_pd.dst_primitive_desc(), - out_grad.GetMKLDNNData()->get_data_handle())); - this->ws_mem.reset( - new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(), - ws->get_data_handle())); - this->diff_src_mem.reset( - new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(), - diff_src_mem->get_data_handle())); - this->bwd.reset(new mkldnn::lrn_backward( - this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem), - mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem, - *this->diff_src_mem)); - } else { - this->in_data_mem->set_data_handle( - in_data.GetMKLDNNData()->get_data_handle()); - this->diff_dst_mem->set_data_handle( - out_grad.GetMKLDNNData()->get_data_handle()); - this->ws_mem->set_data_handle(ws->get_data_handle()); - this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle()); - } - } - - void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) { - MKLDNNStream::Get()->RegisterPrim(*(this->bwd)); - CommitOutput(in_grad, diff_src_mem_); + bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) { + bwd = std::make_shared(bwd_pd); + } + + const mkldnn::lrn_backward &GetBwd() const { return *bwd; } + + void Execute(const NDArray &out_grad, + const NDArray &in_data, + const NDArray &in_grad, + const mkldnn_output_t &diff_src_mem) { + auto engine = CpuEngine::Get()->get_engine(); + auto workspace = std::make_shared((this->fwd_pd).workspace_desc(), engine); + mkldnn_args_map_t args = { + { MKLDNN_ARG_SRC, *in_data.GetMKLDNNData() }, + { MKLDNN_ARG_DIFF_DST, *out_grad.GetMKLDNNData()}, + { MKLDNN_ARG_WORKSPACE, *workspace }, + { MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second } + }; + MKLDNNStream::Get()->RegisterPrimArgs(*(this->bwd), args); + CommitOutput(in_grad, diff_src_mem); MKLDNNStream::Get()->Submit(); } }; // End of LRN Class @@ -277,9 +235,9 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam ¶m, const NDArray &in_data, auto it = lrn_bwds.find(key); if (it == lrn_bwds.end()) { const mkldnn::memory::desc in_data_md = - in_data.GetMKLDNNData()->get_primitive_desc().desc(); + in_data.GetMKLDNNData()->get_desc(); const mkldnn::memory::desc diff_md = - out_grad.GetMKLDNNData()->get_primitive_desc().desc(); + out_grad.GetMKLDNNData()->get_desc(); MKLDNNLRNBwd bwd(param, in_data_md, diff_md); it = AddToCache(&lrn_bwds, key, bwd); } @@ -300,23 +258,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam ¶m, in_buffer = in_data.Reorder2Default(); } MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad); - // Repeat FW for getting workspace - // TODO(Patric): To keep the function stateless, we can't pass workspace - // from LRN forward to backward. We have to re-compute - // LRN forward to get the workspace. - // Will refine this code later. - MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer); - std::shared_ptr dst_temp( - new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc())); - fwd.SetNewMem(in_buffer, dst_temp.get()); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); - mkldnn_output_t diff_src_mem = - CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req); - bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second); - bwd.Execute(in_grad, diff_src_mem); + CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req); + + bwd.Execute(out_grad, in_buffer, in_grad, diff_src_mem); } } // namespace op } // namespace mxnet #endif // MXNET_USE_MKLDNN == 1 #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H__ + diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index c0218f4100b5..71f3eafa8ee9 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -26,7 +26,6 @@ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_ -#if MXNET_USE_MKLDNN == 1 #include #include @@ -36,6 +35,8 @@ #include #include #include + +#if MXNET_USE_MKLDNN == 1 #include namespace mxnet { @@ -71,6 +72,21 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext & const std::vector& req, const std::vector& outputs); +/* For activation */ +void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); +void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &out_grad, const NDArray &in_data, + const OpReqType &req, const NDArray &in_grad); + +void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); +void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, const OpReqType &req, + const NDArray &output); + /* For softmax */ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, @@ -102,20 +118,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector& req, const std::vector& outputs); -/* For activation */ -void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &in_data, const OpReqType &req, - const NDArray &out_data); -void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &out_grad, const NDArray &in_data, - const OpReqType &req, const NDArray &in_grad); -void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &in_data, const OpReqType &req, - const NDArray &out_data); -void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector& inputs, const OpReqType &req, - const NDArray &output); - void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, const mkldnn::memory &out); @@ -130,9 +132,8 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, const NDArray &input, const OpReqType &req, const NDArray &output); - } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 1 #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 9b9f0193979b..22e9abd156a3 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -43,33 +43,26 @@ class MKLDNNPoolingFwd { const int padding_t, const int padding_b, const int padding_l, const int padding_r, const mkldnn::algorithm alg_kind, - const bool with_workspace, const bool is_train) : - is_train_(is_train), + const bool with_workspace, const bool is_train): with_workspace_(with_workspace), - alg_kind_(alg_kind), - fwd_(nullptr), data_(nullptr), out_(nullptr), workspace_(nullptr) { + fwd_(nullptr) { Init(input, output, kernel_h, kernel_w, stride_h, stride_w, - padding_t, padding_b, padding_l, padding_r); + padding_t, padding_b, padding_l, padding_r, + is_train, alg_kind); } ~MKLDNNPoolingFwd() {} - void SetNewMem(const NDArray& in_data, - const NDArray& out_data, - const OpReqType& req, - const mxnet::NDArray *workspace = nullptr); - void Execute(const NDArray& out_data); + void Execute(const NDArray &in_data, + const OpReqType req, + const NDArray& out_data, + const NDArray *workspace); private: - bool is_train_; bool with_workspace_; - mkldnn::algorithm alg_kind_; + std::shared_ptr fwd_pd_; std::shared_ptr fwd_; - std::shared_ptr data_; - std::shared_ptr out_; - std::shared_ptr workspace_; - mkldnn_output_t output_mem_t_; private: void Init(const mxnet::NDArray &input, @@ -77,26 +70,21 @@ class MKLDNNPoolingFwd { const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int padding_t, const int padding_b, - const int padding_l, const int padding_r); + const int padding_l, const int padding_r, + const bool is_train, const mkldnn::algorithm alg_kind); }; class MKLDNNPoolingBwd { std::shared_ptr bwd; - std::shared_ptr diff_dst; - std::shared_ptr diff_src; - std::shared_ptr ws; bool with_workspace; public: const mkldnn::pooling_backward::primitive_desc pd; - MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc, + MKLDNNPoolingBwd(const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws); ~MKLDNNPoolingBwd() {} - void SetNewMem(const mxnet::NDArray *workspace, - const mxnet::NDArray &out_grad, - const mkldnn::memory *diff_src_mem); const mkldnn::pooling_backward &GetBwd(); const mkldnn::pooling_backward::primitive_desc &GetPd(); }; diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index f4d681ded78d..f9dbe5bbfd8f 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -34,18 +34,17 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int padding_t, const int padding_b, - const int padding_l, const int padding_r) { - // mkldnn::memory::desc - auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc(); + const int padding_l, const int padding_r, + const bool is_train, const mkldnn::algorithm alg_kind) { + auto src_md = input.GetMKLDNNData()->get_desc(); mkldnn::memory::dims dims = {src_md.data.dims[0], src_md.data.dims[1], static_cast(output.shape()[2]), static_cast(output.shape()[3])}; auto dst_md = mkldnn::memory::desc({dims}, static_cast(src_md.data.data_type), - static_cast(src_md.data.format)); + mkldnn::memory::format_tag::any); const mkldnn::engine engine = CpuEngine::Get()->get_engine(); - const mkldnn::algorithm alg_kind = this->alg_kind_; if (alg_kind != mkldnn::algorithm::pooling_max && alg_kind != mkldnn::algorithm::pooling_avg && alg_kind != mkldnn::algorithm::pooling_avg_include_padding && @@ -54,10 +53,10 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o } mkldnn::prop_kind prop = mkldnn::prop_kind::forward_scoring; - if (this->is_train_ && alg_kind != mkldnn::algorithm::pooling_avg) { + if (is_train && alg_kind != mkldnn::algorithm::pooling_avg) { prop = mkldnn::prop_kind::forward_training; } - if (this->is_train_ && prop == mkldnn::prop_kind::forward_scoring) { + if (is_train && prop == mkldnn::prop_kind::forward_scoring) { LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring"; } @@ -67,49 +66,43 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o const mkldnn::memory::dims kernel = {kernel_h, kernel_w }; // mkldnn::pooling_forward::desc const auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md, - strides, kernel, pad_l, pad_r, - mkldnn::padding_kind::zero); + strides, kernel, pad_l, pad_r); this->fwd_pd_.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine)); - this->data_.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc())); - this->out_.reset(new mkldnn::memory(this->fwd_pd_->dst_primitive_desc())); - if (this->with_workspace_) { - this->workspace_.reset(new mkldnn::memory(this->fwd_pd_->workspace_primitive_desc())); - this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_), - mkldnn::primitive::at(*(this->data_)), - *(this->out_), - *(this->workspace_))); - } else { - this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_), - mkldnn::primitive::at(*(this->data_)), - *(this->out_))); - } + this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_))); + return; } -void MKLDNNPoolingFwd::SetNewMem(const NDArray& in_data, - const NDArray& out_data, - const OpReqType& req, - const mxnet::NDArray *workspace) { - auto input_mem = in_data.GetMKLDNNData(); - output_mem_t_ = CreateMKLDNNMem(out_data, fwd_pd_->dst_primitive_desc(), req); - // mkldnn::memory - this->data_->set_data_handle(input_mem->get_data_handle()); - this->out_->set_data_handle(output_mem_t_.second->get_data_handle()); - if (this->with_workspace_ && workspace == nullptr) { - LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; - } +void MKLDNNPoolingFwd::Execute(const NDArray &in_data, + const OpReqType req, + const NDArray& out_data, + const NDArray *workspace) { + NDArray in_buffer = in_data; + if (in_data.IsView() && in_data.IsMKLDNNData()) + in_buffer = in_data.Reorder2Default(); + + auto input_mem = in_buffer.GetMKLDNNData(); + auto output_mem_t_ = CreateMKLDNNMem(out_data, this->fwd_pd_->dst_desc(), req); + + mkldnn_args_map_t args = { + {MKLDNN_ARG_SRC, *input_mem }, + {MKLDNN_ARG_DST, *(output_mem_t_.second) }, + }; if (this->with_workspace_) { - // mkldnn::memory - auto ws_mem = workspace->GetMKLDNNData(); - this->workspace_->set_data_handle(ws_mem->get_data_handle()); - } -} + auto engine = CpuEngine::Get()->get_engine(); + + if (workspace == nullptr) { + LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; + } -void MKLDNNPoolingFwd::Execute(const NDArray& out_data) { + auto ws = std::make_shared((*(this->fwd_pd_)).workspace_desc(), + engine, workspace->GetMKLDNNData()->get_data_handle()); + args[MKLDNN_ARG_WORKSPACE] = *ws; + } if (this->fwd_) { - MKLDNNStream::Get()->RegisterPrim(*(this->fwd_)); - CommitOutput(out_data, this->output_mem_t_); + MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd_), args); + CommitOutput(out_data, output_mem_t_); MKLDNNStream::Get()->Submit(); } else { LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr"; @@ -143,8 +136,8 @@ static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) { } mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( - const PoolingParam ¶m, const bool is_train, const memory::desc &data_md, - const memory::desc &out_md) { + const PoolingParam ¶m, const bool is_train, const mkldnn::memory::desc &data_md, + const mkldnn::memory::desc &out_md) { CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; int kernel_h_, kernel_w_; if (param.global_pool) { @@ -183,19 +176,18 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring; - if (is_train && alg != algorithm::pooling_avg) { + if (is_train && alg != mkldnn::algorithm::pooling_avg) { kind = mkldnn::prop_kind::forward_training; } - const pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, + const mkldnn::pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, {static_cast(stride_h_), static_cast(stride_w_)}, {kernel_h_, kernel_w_}, {static_cast(pad_t_), static_cast(pad_l_)}, {static_cast(pad_b_), - static_cast(pad_r_)}, - padding_kind::zero); + static_cast(pad_r_)}); return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine); } @@ -223,7 +215,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, auto it = pooling_fwds.find(key); if (it == pooling_fwds.end()) { CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; - auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc(); + auto data_md = data.GetMKLDNNData()->get_desc(); int kernel_h_, kernel_w_; if (param.global_pool) { kernel_h_ = data_md.data.dims[2]; @@ -270,42 +262,14 @@ void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &in_data, const OpReqType req, const NDArray &out_data, const NDArray *workspace) { auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); - fwd.SetNewMem(in_data, out_data, req, workspace); - fwd.Execute(out_data); + fwd.Execute(in_data, req, out_data, workspace); } MKLDNNPoolingBwd::MKLDNNPoolingBwd( - const pooling_backward::primitive_desc &pdesc, bool with_ws) - : with_workspace(with_ws), pd(pdesc) {} - -void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace, - const mxnet::NDArray &out_grad, - const mkldnn::memory *diff_src_mem) { - if (bwd == nullptr) { - diff_dst.reset( - new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(), - out_grad.GetMKLDNNData()->get_data_handle())); - diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(), - diff_src_mem->get_data_handle())); - if (with_workspace) { - CHECK(workspace != nullptr); - ws.reset( - new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(), - workspace->GetMKLDNNData()->get_data_handle())); - bwd.reset( - new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src)); - } else { - bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src)); - } - } else { - diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle()); - diff_src->set_data_handle(diff_src_mem->get_data_handle()); - if (with_workspace) { - CHECK(workspace != nullptr); - ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle()); + const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws) + : with_workspace(with_ws), pd(pdesc) { + bwd = std::make_shared(pd); } - } -} const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() { return *this->bwd; @@ -333,27 +297,29 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, auto it = pooling_bwds.find(key); if (it == pooling_bwds.end()) { - auto diff_dst_mem = out_grad.GetMKLDNNData(); + NDArray diff_dst_buff = out_grad; + if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) { + diff_dst_buff = out_grad.Reorder2Default(); + } + auto diff_dst_mem = diff_dst_buff.GetMKLDNNData(); auto input_mem = in_data.GetMKLDNNData(); - mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); - const mkldnn::memory::desc data_md = data_mpd.desc(); - const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], + const mkldnn::memory::desc data_md = input_mem->get_desc(); + const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], static_cast(out_grad.shape()[2]), static_cast(out_grad.shape()[3])}; - const memory::desc out_md( - {dims}, static_cast(data_md.data.data_type), - static_cast(data_md.data.format)); + const mkldnn::memory::desc out_md( + {dims}, static_cast(data_md.data.data_type), + mkldnn::memory::format_tag::any); auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md); - const mkldnn::memory::desc diff_md = - diff_dst_mem->get_primitive_desc().desc(); - const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], + diff_dst_mem->get_desc(); + const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], static_cast(in_grad.shape()[2]), static_cast(in_grad.shape()[3])}; - const memory::desc diff_in_md( - {dims1}, static_cast(diff_md.data.data_type), - static_cast(diff_md.data.format)); - const mkldnn::engine cpu_engine = data_mpd.get_engine(); + const mkldnn::memory::desc diff_in_md( + {dims1}, static_cast(diff_md.data.data_type), + mkldnn::memory::format_tag::any); + const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();; const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); int kernel_h_, kernel_w_; @@ -379,11 +345,10 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, stride_h_ = stride_w_ = 1; } - const pooling_backward::desc desc( + const mkldnn::pooling_backward::desc desc( alg, diff_in_md, diff_md, {stride_h_, stride_w_}, - {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}, - mkldnn::padding_kind::zero); - const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); + {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}); + const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); MKLDNNPoolingBwd bwd(pdesc, with_workspace); it = AddToCache(&pooling_bwds, key, bwd); } @@ -401,10 +366,17 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad); auto diff_src_mem = - CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); + CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req); + + mkldnn_args_map_t args = { + {MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())}, + {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }, + }; + if (MKLDNNRequireWorkspace(param) && workspace != nullptr) { + args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData()); + } - bwd.SetNewMem(workspace, out_grad, diff_src_mem.second); - MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); + MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args); CommitOutput(in_grad, diff_src_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h index 726d72156718..c89e4585e85d 100644 --- a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h @@ -1,59 +1,61 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file mkldnn_reshape-inl.h - * \brief Function definition of mkldnn reshape operator - */ - -#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ -#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ - -#if MXNET_USE_MKLDNN == 1 -#include -#include "mkldnn_base-inl.h" -#include "../../tensor/matrix_op-inl.h" - -namespace mxnet { -namespace op { - -class MKLDNNReshapeFwd { - public: - MKLDNNReshapeFwd(const OpReqType &req, const NDArray &input, const NDArray &output); - int GetWorkspaceSize(); - void SetNewMem(const NDArray &input, const NDArray &output, void *workspace = nullptr); - void Execute(const NDArray &input, const NDArray &output, void *workspace = nullptr); - - private: - std::shared_ptr data_; - std::shared_ptr out_; - std::shared_ptr temp_; - std::vector prims_; -}; - -typedef OpSignature MKLDNNReshapeSignature; -MKLDNNReshapeFwd &GetReshapeForward(const OpReqType &req, const NDArray &input, - const NDArray &output); - -} // namespace op -} // namespace mxnet - -#endif // MXNET_USE_MKLDNN == 1 -#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_reshape-inl.h + * \brief Function definition of mkldnn reshape operator + */ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ + +#if MXNET_USE_MKLDNN == 1 +#include +#include "mkldnn_base-inl.h" +#include "../../tensor/matrix_op-inl.h" + +namespace mxnet { +namespace op { + +class MKLDNNReshapeFwd { + protected: + std::shared_ptr out_; + std::shared_ptr temp_; + std::vector prims_; + + public: + MKLDNNReshapeFwd(const OpReqType &req, + const NDArray &input, + const NDArray &output); + int GetWorkspaceSize(); + void Execute(const NDArray &input, + const NDArray &output, + const OpReqType &req, + void* workspace = nullptr); +}; + +typedef OpSignature MKLDNNReshapeSignature; +MKLDNNReshapeFwd &GetReshapeForward(const OpReqType &req, const NDArray &input, + const NDArray &output); +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc index 9c226a052b0b..0fc9f20703af 100644 --- a/src/operator/nn/mkldnn/mkldnn_reshape.cc +++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc @@ -35,58 +35,65 @@ namespace op { MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req, const NDArray &input, const NDArray &output) { const auto engine = CpuEngine::Get()->get_engine(); - data_ = std::make_shared(input.GetMKLDNNData()->get_primitive_desc(), nullptr); + auto in_mem = input.GetMKLDNNData(); + // Create temp memory auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end()); auto temp_type = static_cast(get_mkldnn_type(input.dtype())); - auto temp_fmt = static_cast(GetDefaultFormat(input.shape().ndim())); + auto temp_fmt = static_cast(GetDefaultFormat(input.shape().ndim())); auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt); - auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine); - out_ = std::make_shared(temp_pd, nullptr); + + out_ = std::make_shared(temp_desc, engine, nullptr); if (req == kWriteInplace) { // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with // default layout and copy from the temporal buffer back to output buffer which has the same // address with input buffer. // If the input has default layout, then nothing need to do. if (input.IsMKLDNNData()) { - temp_ = std::make_shared(temp_pd, nullptr); - prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default + temp_ = std::make_shared(temp_desc, engine, nullptr); + prims_.push_back(mkldnn::reorder(*in_mem, *temp_)); // reorder to default prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy back } } else if (req == kWriteTo) { - prims_.push_back(mkldnn::reorder(*data_, *out_)); + prims_.push_back(mkldnn::reorder(*in_mem, *out_)); } else { LOG(FATAL) << "not supported req type: " << req; } } int MKLDNNReshapeFwd::GetWorkspaceSize() { - return temp_ ? temp_->get_primitive_desc().get_size() : 0; -} - -void MKLDNNReshapeFwd::SetNewMem(const NDArray &input, - const NDArray &output, - void* workspace) { - this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle()); - this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle()); - if (workspace) { - this->temp_->set_data_handle(workspace); - } + return temp_ ? temp_->get_desc().get_size() : 0; } void MKLDNNReshapeFwd::Execute(const NDArray &input, const NDArray &output, + const OpReqType &req, void* workspace) { - if (this->prims_.size()) { - // set memory handles - SetNewMem(input, output, workspace); - // register primitives - auto stream = MKLDNNStream::Get(); - for (auto &v : this->prims_) { - stream->RegisterPrim(v); + auto stream = MKLDNNStream::Get(); + auto in_mem = input.GetMKLDNNData(); + // register primitives and arguments + std::vector args_map; + size_t prims_size = prims_.size(); + if (prims_size == 1) { + args_map.push_back({{MKLDNN_ARG_FROM, *in_mem}, + {MKLDNN_ARG_TO, *output.GetMKLDNNData()}}); + } else if (prims_size == 2) { + if (workspace) { + temp_->set_data_handle(workspace); } - stream->Submit(); + args_map.push_back({{MKLDNN_ARG_FROM, *in_mem}, + {MKLDNN_ARG_TO, *temp_}}); + args_map.push_back({{MKLDNN_ARG_FROM, *temp_}, + {MKLDNN_ARG_TO, *output.GetMKLDNNData()}}); + } else { + CHECK(prims_size == 0 && req != kWriteTo) + << "kWriteTo should never reach here."; + } + + for (size_t i = 0; i < prims_size; i++) { + stream->RegisterPrimArgs(prims_[i], args_map[i]); } + stream->Submit(); // invalidate mkldnn memory in output const_cast(output).InvalidateMKLDNNData(); } @@ -137,7 +144,7 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, .get_space_typed(mshadow::Shape1(ws_size), s); ws_ptr = static_cast(ws.dptr_); } - fwd.Execute(input, output, ws_ptr); + fwd.Execute(input, output, req, ws_ptr); } } // namespace op diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h new file mode 100644 index 000000000000..ad3f7332a8f4 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_rnn-inl.h + * \brief Common functions used by MKLDNN RNN operator + * \author Zixuan Wei +*/ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_INL_H_ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include "../../rnn-inl.h" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +struct MKLDNNRnnLayerParam { + using memory = mkldnn::memory; + using dims = mkldnn::memory::dims; + + int mode; + bool bidirectional; + bool state_outputs; + int num_layer; + int batch_size; + int input_size; + int state_size; + int seq_len; + + dims src_dims; // Dimensions of source input in format_tag::tnc + dims weight_layer_dims; // Dimensions of layer weights in format_tag::ldigo + dims weight_iter_dims; // Dimensions of iter weights in format_tag::ldigo + dims bias_dims; // Dimensions of bias in format_tag::ldgo + dims dst_dims; // Dimensions of output in format_tag::tnc + dims state_dims; // Dimensions of the state cell in format_tag::ldnc + + size_t workspace_size; // used for the cached mkl-dnn memory in Forward inference + size_t reserve_size; // used for the reserved cached memory in Backward + size_t single_w_size; // weights size of a single cell + size_t single_b_size; // bias size of a single cell from mkl-dnn + size_t naive_single_b_size; // bias size of a single cell from framework + size_t single_state_size; // state size of a single cell, hy, cy + + MKLDNNRnnLayerParam(int num_layer, int batch_size, int seq_len, + int input_size, int state_size, + int mode, bool bidirectional = true) + : mode(mode), bidirectional(bidirectional), state_outputs(true), + num_layer(num_layer), batch_size(batch_size), input_size(input_size), + state_size(state_size), seq_len(seq_len) { } + + void SetDims(); +}; + +typedef std::vector LayerParamVector; +struct MKLDNNRnnFullParam { + RNNParam default_param; + LayerParamVector layer_params; +}; + +MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const int seq_len, + const int batch_size, const int input_size); + +/* + * Use this to allocate memory from MKLDNNRnnOp temporary space. + */ +class MKLDNNRnnMemMgr { + // The memory buffer in NDArray life-cycle + NDArray workspace_; + // This points to the memory buffer from a NDArray + char *curr_mem; + // The total bytes of the workspace of a MKLDNNRnnOp + size_t mem_size = 0; + // The current available memory bytes + size_t curr_size = 0; + const size_t alignment = kMKLDNNAlign; + const mkldnn::engine& cpu_engine = CpuEngine::Get()->get_engine(); + // Here we hold all memory related to the stateful RNN operators + std::vector > mem_holder; + + public: + void Init(dim_t size, const Context& ctx, int dtype = mshadow::kFloat32); + + void RegisterMem(std::shared_ptr mem) { + mem_holder.push_back(mem); + } + + mkldnn::memory *Alloc(const mkldnn::memory::desc &md); +}; + +/* + * Rnn Primitive. + */ +class RnnPrimitive { + public: + /* Create a RnnPrimitive with rnn type: + * lstm_forward, lbr_gru_forward, vanilla_rnn_forward + */ + template + static RnnPrimitive Create(Args&&... args) { + RnnPrimitive rnn_fwd_prim; + rnn_fwd_prim.pd_.reset( + new typename rnn_fwd::desc(std::forward(args)...), + [](typename rnn_fwd::desc* pd) { + delete reinterpret_cast(pd); + }); + const typename rnn_fwd::desc& fwd_desc = + *(reinterpret_cast(rnn_fwd_prim.pd_.get())); + typename rnn_fwd::primitive_desc fwd_pd(fwd_desc, CpuEngine::Get()->get_engine()); + rnn_fwd_prim.weights_layer_desc_ = fwd_pd.weights_layer_desc(); + rnn_fwd_prim.weights_iter_desc_ = fwd_pd.weights_iter_desc(); + rnn_fwd_prim.workspace_desc_ = fwd_pd.workspace_desc(); + + rnn_fwd_prim.primitive_ = std::shared_ptr(new rnn_fwd(fwd_pd)); + + return rnn_fwd_prim; + } + + RnnPrimitive() { + this->pd_ = nullptr; + this->primitive_ = nullptr; + this->weights_layer_desc_ = mkldnn::memory::desc(); + this->weights_iter_desc_ = mkldnn::memory::desc(); + this->workspace_desc_ = mkldnn::memory::desc(); + } + + RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) { + this->pd_ = rnn_fwd_prim.pd_; + this->primitive_ = rnn_fwd_prim.primitive_; + this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; + this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_; + this->workspace_desc_ = rnn_fwd_prim.workspace_desc_; + } + + RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) { + if (this != &rnn_fwd_prim) { + this->pd_ = rnn_fwd_prim.pd_; + this->primitive_ = rnn_fwd_prim.primitive_; + this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; + this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_; + this->workspace_desc_ = rnn_fwd_prim.workspace_desc_; + } + + return *this; + } + + const void* GetPrimDesc() const { return pd_.get(); } + const mkldnn::primitive& GetPrim() const { return *primitive_; } + + const mkldnn::memory::desc& GetLayerDesc() const { + return weights_layer_desc_; + } + + const mkldnn::memory::desc& GetIterDesc() const { + return weights_iter_desc_; + } + + const mkldnn::memory::desc& GetWorkspaceDesc() const { + return workspace_desc_; + } + + private: + std::shared_ptr pd_; + std::shared_ptr primitive_; + mkldnn::memory::desc weights_layer_desc_; + mkldnn::memory::desc weights_iter_desc_; + mkldnn::memory::desc workspace_desc_; +}; + +RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam &layer_param, const bool is_train, + const NDArray &data, const NDArray ¶ms); + +/* + * Use this to manage memory and primitive of MKL-DNN RNN forward inference. + */ +class MKLDNNRnnForward { + public: + MKLDNNRnnForward(const MKLDNNRnnLayerParam &layer_param, const bool is_train, + const NDArray &data, const NDArray ¶ms) + : initialized_(false), param_(layer_param), + fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) { } + + void SetNewDataMem(void* x, void* hx, void* cx, + void* y, void* hy, void* cy, + const int dtype = mshadow::kFloat32); + void SetWeightsMem(MKLDNNRnnMemMgr* mgr, void* w_ptr, void* b_ptr, + const bool is_train = false, + const int dtype = mshadow::kFloat32); + void ReorderWeights(); + + const mkldnn::primitive& GetFwd() const { return fwd_inf_.GetPrim(); } + + const size_t GetSize(int dtype) const { + size_t bytes = mshadow::mshadow_sizeof(dtype); + size_t size = 0; + size += fwd_inf_.GetLayerDesc().get_size(); + size += fwd_inf_.GetIterDesc().get_size(); + return size / bytes + 1; + } + + const MKLDNNRnnLayerParam &GetParam() const { return param_; } + + const mkldnn_args_map_t &GetArgsMap() const { return net_args_; } + + const bool IsInitialized() const { return initialized_; } + void Reset() { initialized_ = false; } + + private: + bool initialized_; + MKLDNNRnnLayerParam param_; + RnnPrimitive fwd_inf_; // forward inference primitive + + mkldnn::memory *weights_layer_ = nullptr; + mkldnn::memory *weights_iter_ = nullptr; + mkldnn::memory *bias_ = nullptr; + + mkldnn::memory *weights_layer_r_ = nullptr; + mkldnn::memory *weights_iter_r_ = nullptr; + + /* + * net_args must contain some keys as below: + * MKLDNN_ARG_SRC + * MKLDNN_ARG_SRC_ITER + * MKLDNN_WEIGHTS_LAYER + * MKLDNN_WEIGHTS_ITER + * MKLDNN_BIAS + * MKLDNN_ARG_DST + * MKLDNN_ARG_DST_ITER + * if mode == Lstm, it also needs two additional key: + * MKLDNN_ARG_SRC_ITER_C + * MKLDNN_ARG_DST_ITER_C + */ + mkldnn_args_map_t net_args_; + + friend class MKLDNNRnnForwardTraining; +}; + +typedef std::shared_ptr mkldnn_shared_mem_t; +/* + * Use this to manage memory and primitive of MKL-DNN RNN forward training. + */ +class MKLDNNRnnForwardTraining { + public: + MKLDNNRnnForwardTraining(const MKLDNNRnnLayerParam &layer_param, const bool is_train, + const NDArray &data, const NDArray ¶ms) + : fwd_trn_(GetRnnFwdPrim(layer_param, is_train, data, params)), + param_(&layer_param) { } + + void SetTrnMem(const MKLDNNRnnForward& fwd); + void FetchData(const MKLDNNRnnForward& fwd); + + const MKLDNNRnnLayerParam& GetParam() const { return *param_; } + const void* GetPrimDesc() const { return fwd_trn_.GetPrimDesc(); } + const mkldnn::primitive& GetFwd() const { return fwd_trn_.GetPrim(); } + const mkldnn_args_map_t& GetArgsMap() const { return net_args_; } + + private: + RnnPrimitive fwd_trn_; + const MKLDNNRnnLayerParam* param_; + + mkldnn_shared_mem_t weights_layer_ = nullptr; + mkldnn_shared_mem_t weights_iter_ = nullptr; + mkldnn::memory* bias_ = nullptr; + + mkldnn_shared_mem_t workspace_ = nullptr; + + // Key MKLDNN_ARGS_WORKSPACE must be included in forward training + mkldnn_args_map_t net_args_; + + friend class MKLDNNRnnBackward; +}; + +/* + * Rnn Backward primitive + */ +class RnnBwdPrimitive { + public: + template + static RnnBwdPrimitive Create(typename rnn_fwd::primitive_desc const & fwd_pd, Args&&... args) { + RnnBwdPrimitive bwd_prim; + typename rnn_bwd::desc bwd_desc(std::forward(args)...); + typename rnn_bwd::primitive_desc bwd_pd(bwd_desc, CpuEngine::Get()->get_engine(), fwd_pd); + bwd_prim.weights_layer_desc_ = bwd_pd.weights_layer_desc(); + bwd_prim.weights_iter_desc_ = bwd_pd.weights_iter_desc(); + bwd_prim.diff_weights_layer_desc_ = bwd_pd.diff_weights_layer_desc(); + bwd_prim.diff_weights_iter_desc_ = bwd_pd.diff_weights_iter_desc(); + bwd_prim.diff_bias_desc_ = bwd_pd.diff_bias_desc(); + + bwd_prim.primitive_ = std::shared_ptr(new rnn_bwd(bwd_pd)); + + return bwd_prim; + } + + RnnBwdPrimitive() { + this->primitive_ = nullptr; + this->weights_layer_desc_ = mkldnn::memory::desc(); + this->weights_iter_desc_ = mkldnn::memory::desc(); + this->diff_weights_layer_desc_ = mkldnn::memory::desc(); + this->diff_weights_iter_desc_ = mkldnn::memory::desc(); + this->diff_bias_desc_ = mkldnn::memory::desc(); + } + + RnnBwdPrimitive(const RnnBwdPrimitive& bwd) { + this->primitive_ = bwd.primitive_; + this->weights_layer_desc_ = bwd.weights_layer_desc_; + this->weights_iter_desc_ = bwd.weights_iter_desc_; + this->diff_weights_layer_desc_ = bwd.diff_weights_layer_desc_; + this->diff_weights_iter_desc_ = bwd.diff_weights_iter_desc_; + this->diff_bias_desc_ = bwd.diff_bias_desc_; + } + + RnnBwdPrimitive& operator=(const RnnBwdPrimitive& bwd) { + if (this != &bwd) { + this->primitive_ = bwd.primitive_; + this->weights_layer_desc_ = bwd.weights_layer_desc_; + this->weights_iter_desc_ = bwd.weights_iter_desc_; + this->diff_weights_layer_desc_ = bwd.diff_weights_layer_desc_; + this->diff_weights_iter_desc_ = bwd.diff_weights_iter_desc_; + this->diff_bias_desc_ = bwd.diff_bias_desc_; + } + + return *this; + } + + private: + std::shared_ptr primitive_; + mkldnn::memory::desc weights_layer_desc_; + mkldnn::memory::desc weights_iter_desc_; + mkldnn::memory::desc diff_weights_layer_desc_; + mkldnn::memory::desc diff_weights_iter_desc_; + mkldnn::memory::desc diff_bias_desc_; + friend class MKLDNNRnnBackward; +}; +RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining& fwd, + const NDArray& data, const NDArray& params); + +/* + * Use this to manage memory and primitive of MKL-DNN RNN backward. + */ +class MKLDNNRnnBackward { + public: + MKLDNNRnnBackward(const MKLDNNRnnForwardTraining& fwd, + const NDArray& data, const NDArray& params) + : bwd_(GetRnnBwdPrim(fwd, data, params)), + fwd_ptr_(&fwd) { } + + void FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd); + void SetWeightsGradsMem(); + void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell, + void* diff_out, void* diff_state_out, void* diff_statecell_out, + const int dtype = mshadow::kFloat32); + void CommitWeightsDiff(void* diff_weights, void* diff_bias, const int dtype = mshadow::kFloat32); + + const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; } + const mkldnn_args_map_t& GetArgsMap() const { return net_args_; } + + private: + bool initialized_; + RnnBwdPrimitive bwd_; + const MKLDNNRnnForwardTraining* fwd_ptr_; + + mkldnn_shared_mem_t weights_layer_; + mkldnn_shared_mem_t weights_iter_; + + mkldnn_shared_mem_t diff_weights_layer_; + mkldnn_shared_mem_t diff_weights_iter_; + mkldnn_shared_mem_t diff_bias_; + + mkldnn_args_map_t net_args_; +}; + +/* + * Use MKLDNNRnnOp to manage fused or unfused RNN layers. A MKLDNNRnnOp contains + * the parameter(s) and primitive(s) of RNN layer(s). According to the direction, + * input size, and state size, multple layers could be inplemented by unfused and + * fused layers - MKLDNNRnnForward, which holds the memory and forward primitive + * of MKL-DNN. + */ +class MKLDNNRnnOp { + public: + explicit MKLDNNRnnOp(const RNNParam ¶m, const int seq_len, + const int batch_size, const int input_size) + : initialized_(false), weights_version_(0), + full_param_(MKLDNNRnnFullParamParser(param, seq_len, batch_size, input_size)) { } + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + const RNNParam& GetParam() const { return full_param_.default_param; } + + private: + bool initialized_; + size_t weights_version_; + MKLDNNRnnFullParam full_param_; + MKLDNNRnnMemMgr mgr_; + std::vector fwd_inf_vec_; // forward inference layers + std::vector fwd_trn_vec_; // forward training layers + std::vector bwd_vec_; // backward layers + + // Used to store the intermediate results of multi-layer + std::vector dst_; + + // Used to store the intermediate diff_src of multi_layer + mkldnn_shared_mem_t diff_src; + + void Init(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc new file mode 100644 index 000000000000..e797b649d295 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc @@ -0,0 +1,1118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mkldnn_rnn.cc + * \brief Common functions used by MKLDNN RNN operator + * \author Zixuan Wei +*/ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include "./mkldnn_rnn-inl.h" + +namespace mxnet { +namespace op { + +inline int GetRnnGatesNum(int mode) { + switch (mode) { + case rnn_enum::kLstm: + return 4; + case rnn_enum::kGru: + return 3; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + return 1; + default: + LOG(FATAL) << "unsupported RNN mode:" << mode; + return -1; + } +} + +void MKLDNNRnnLayerParam::SetDims() { + const int ngates = GetRnnGatesNum(mode); + //* NOTES: LBR-GRU's new gate formula needs two bias. So it has one more bias with LBR-GRU + const int nbias = mode == rnn_enum::kGru ? (ngates + 1) : ngates; + const int num_direction = bidirectional ? 2 : 1; + + src_dims.assign({seq_len, batch_size, input_size}); + weight_layer_dims.assign({num_layer, num_direction, input_size, ngates, state_size}); + weight_iter_dims.assign({num_layer, num_direction, state_size, ngates, state_size}); + bias_dims.assign({num_layer, num_direction, nbias, state_size}); + dst_dims.assign({seq_len, batch_size, state_size * num_direction}); + state_dims.assign({num_layer, num_direction, batch_size, state_size}); + + // unidirectional size of a single cell + single_w_size = (input_size + state_size) * ngates * state_size; + single_b_size = nbias * state_size; + naive_single_b_size = ngates * state_size * 2; // naive RNN variants have double bias + single_state_size = batch_size * state_size; + + // Get workspace size for cached weights memory + // multiplication of tensor dimensions + static auto tz_volume = [](const memory::dims& tz_dims) { + return std::accumulate(tz_dims.begin(), tz_dims.end(), static_cast(1), + std::multiplies()); + }; + + workspace_size = tz_volume(weight_layer_dims) + tz_volume(weight_iter_dims) + + tz_volume(bias_dims); + reserve_size = 0; +} + +MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const int seq_len, + const int batch_size, const int input_size) { + MKLDNNRnnFullParam full_param; + full_param.default_param = rnn_param; + size_t state_size = rnn_param.state_size; + LayerParamVector &layer_params = full_param.layer_params; + + full_param.default_param.seq_length_ = seq_len; + full_param.default_param.batch_size_ = batch_size; + full_param.default_param.input_size_ = input_size; + // Set basic size by constructing MKLDNNRnnLayerParam instance(s) + if (rnn_param.bidirectional) { // unfused bidirectional multi-layer RNN + layer_params.emplace_back(1, batch_size, seq_len, input_size, state_size, rnn_param.mode); + for (size_t layer = 1; layer < rnn_param.num_layers; ++layer) { + layer_params.emplace_back(1, batch_size, seq_len, state_size * 2, state_size, + rnn_param.mode); + } + } else if (input_size == static_cast(state_size)) { // fused multi-layer RNN + layer_params.emplace_back(rnn_param.num_layers, batch_size, seq_len, input_size, + state_size, rnn_param.mode, false); + } else { // unfused 1st layer, plus fused 2-end layers + layer_params.emplace_back(1, batch_size, seq_len, input_size, state_size, rnn_param.mode, + false); + if (rnn_param.num_layers > 1) + layer_params.emplace_back(rnn_param.num_layers - 1, batch_size, seq_len, state_size, + state_size, rnn_param.mode, false); + } + + // Set dims, workspace size, and state_outputs flag + for (auto& layer_param : layer_params) { + layer_param.SetDims(); + layer_param.state_outputs = rnn_param.state_outputs; + } + return full_param; +} + +void MKLDNNRnnMemMgr::Init(dim_t size, const Context& ctx, int dtype) { + workspace_ = NDArray(TShape({size}), ctx, false, dtype); + curr_mem = static_cast(workspace_.data().dptr_); + mem_size = size * mshadow::mshadow_sizeof(dtype); + curr_size = size * mshadow::mshadow_sizeof(dtype); +} + +mkldnn::memory *MKLDNNRnnMemMgr::Alloc(const mkldnn::memory::desc &md) { + if (curr_mem == nullptr) { + curr_mem = static_cast(workspace_.data().dptr_); + } + + mkldnn_mem_ptr ret(new mkldnn::memory()); + size_t addr = reinterpret_cast(curr_mem); + size_t last_chunk = addr % alignment; + size_t padding = alignment - last_chunk; + addr += padding; + CHECK_EQ(addr % alignment, 0); + + curr_size -= (md.get_size() + padding); + if (curr_size < 0) { + ret.reset(new mkldnn::memory(md, cpu_engine)); + } else { + curr_mem += (md.get_size() + padding); + ret.reset(new mkldnn::memory(md, cpu_engine, reinterpret_cast(addr))); + } + RegisterMem(ret); + return ret.get(); +} + +RnnPrimitive GetRnnFwdPrim( + const MKLDNNRnnLayerParam &layer_param, const bool is_train, + const NDArray &data, const NDArray ¶ms) { + using namespace mkldnn; + using tag = mkldnn::memory::format_tag; + const int mode = layer_param.mode; + memory::data_type data_type = get_mkldnn_type(data.dtype()); + memory::data_type weight_type = get_mkldnn_type(params.dtype()); + const prop_kind prop = is_train ? prop_kind::forward_training : prop_kind::forward_inference; + const rnn_direction mkldnn_rnn_direction = layer_param.bidirectional ? + rnn_direction::bidirectional_concat : rnn_direction::unidirectional; + + auto src_layer_desc = memory::desc(layer_param.src_dims, data_type, tag::tnc); + auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_type, tag::any); + auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_type, tag::any); + auto bias_desc = memory::desc(layer_param.bias_dims, data_type, tag::ldgo); + auto dst_layer_desc = memory::desc(layer_param.dst_dims, data_type, tag::tnc); + auto src_state_desc = memory::desc(layer_param.state_dims, data_type, tag::ldnc); + auto dst_state_desc = layer_param.state_outputs ? memory::desc( + layer_param.state_dims, data_type, tag::ldnc) : memory::desc(); + + auto fwd = RnnPrimitive(); + switch (mode) { + case rnn_enum::kLstm: + fwd = RnnPrimitive::Create(prop, mkldnn_rnn_direction, + src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + dst_state_desc); + break; + case rnn_enum::kGru: + fwd = RnnPrimitive::Create(prop, mkldnn_rnn_direction, + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc); + break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + fwd = RnnPrimitive::Create(prop, + mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu, + mkldnn_rnn_direction, src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc); + break; + default: + LOG(FATAL) << "unsupported RNN mode:" << mode; + break; + } + return fwd; +} + +RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining &fwd, + const NDArray &data, const NDArray ¶ms) { + using namespace mkldnn; + using tag = mkldnn::memory::format_tag; + const MKLDNNRnnLayerParam& layer_param = fwd.GetParam(); + const int mode = layer_param.mode; + memory::data_type data_type = get_mkldnn_type(data.dtype()); + memory::data_type weight_type = get_mkldnn_type(params.dtype()); + const prop_kind prop = prop_kind::backward; + rnn_direction mkldnn_rnn_direction = layer_param.bidirectional ? + rnn_direction::bidirectional_concat : rnn_direction::unidirectional; + + auto src_layer_desc = memory::desc(layer_param.src_dims, data_type, tag::tnc); + auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_type, tag::any); + auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_type, tag::any); + auto bias_desc = memory::desc(layer_param.bias_dims, data_type, tag::ldgo); + auto dst_layer_desc = memory::desc(layer_param.dst_dims, data_type, tag::tnc); + auto src_state_desc = memory::desc(layer_param.state_dims, data_type, tag::ldnc); + auto dst_state_desc = layer_param.state_outputs ? memory::desc( + layer_param.state_dims, data_type, tag::ldnc) : memory::desc(); + + const void* fwd_desc = fwd.GetPrimDesc(); + auto bwd = RnnBwdPrimitive(); + switch (mode) { + case rnn_enum::kLstm: { + const lstm_forward::primitive_desc* desc = + reinterpret_cast(fwd_desc); + bwd = RnnBwdPrimitive::Create(*desc, + prop, mkldnn_rnn_direction, + // data desc + src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + dst_state_desc, + // diff desc + src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + dst_state_desc); + } break; + case rnn_enum::kGru: { + const lbr_gru_forward::primitive_desc* desc = + reinterpret_cast(fwd_desc); + bwd = RnnBwdPrimitive::Create(*desc, + prop, mkldnn_rnn_direction, + // data desc + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + // diff desc + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc); + } break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: { + const vanilla_rnn_forward::primitive_desc* desc = + reinterpret_cast(fwd_desc); + bwd = RnnBwdPrimitive::Create( + *desc, prop, + mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu, + mkldnn_rnn_direction, + // data desc + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc, + // diff desc + src_layer_desc, src_state_desc, weight_layer_desc, + weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc); + } break; + default: + LOG(FATAL) << "unsupported RNN mode:" << mode; + break; + } + return bwd; +} + +/* + * Naive weights layout is: + * | l0_l2r_wx | l0_l2r_wh | l0_r2l_wx | l0_r2l_wh | + * | l1_l2r_wx | l1_l2r_wh | l1_r2l_wx | l1_r2l_wh | + * ... + * + * We need concat them to be: + * | l0_l2r_wx | l0_r2l_wx | l1_l2r_wx | l1_r2l_wx | + * | l0_l2r_wh | l0_r2l_wh | l1_l2r_wh | l1_r2l_wh | + * ... + * + * All the memory blocks are in goi format. + */ +static void ConcatWeights(const mkldnn::memory &dst, + const int concat_dimension, + const std::vector &src_ptrs, + const mkldnn::memory::format_tag src_format) { + using memory = mkldnn::memory; + auto cpu_engine = dst.get_engine(); + mkldnn::stream s(cpu_engine); + const memory::desc& dst_desc = dst.get_desc(); + // Use dst memory dims to initialize src memory dims, then set the concat + // dim to 1. And Rnn weights are 5-dimension tensor. + memory::dims src_dims(dst_desc.data.dims, dst_desc.data.dims + 5); + src_dims.at(concat_dimension) = 1; + std::vector src_descs; + std::unordered_map concat_args; + + for (size_t i = 0; i < src_ptrs.size(); ++i) { + src_descs.emplace_back(src_dims, + static_cast(dst_desc.data.data_type), src_format); + concat_args.emplace(MKLDNN_ARG_MULTIPLE_SRC + i, + memory(src_descs.back(), cpu_engine, src_ptrs.at(i))); + } + concat_args.emplace(MKLDNN_ARG_DST, dst); + + auto concat_pd = mkldnn::concat::primitive_desc(dst.get_desc(), + concat_dimension, src_descs, cpu_engine); + mkldnn::concat(concat_pd).execute(s, concat_args); +} + +#define RNN_HANDLE_FUNC_NAME set_handle +#define RNN_HANDLE_FUNC(RNN_FUNC_NAME) \ +auto RNN_FUNC_NAME = [&cpu_engine, &args](int arg_name, const desc& md, \ + void* handle) { \ + if (args.find(arg_name) != args.end()) { \ + if (handle != nullptr) args.at(arg_name).set_data_handle(handle); \ + } else { \ + args[arg_name] = handle ? mkldnn::memory(md, cpu_engine, handle) \ + : mkldnn::memory(md, cpu_engine); \ + } \ +} + +#define RNN_FWD_SET(NAME, DIMS, TAG, HANDLE, DTYPE) \ +RNN_FWD_SET_(RNN_HANDLE_FUNC_NAME, NAME, DIMS, TAG, HANDLE, DTYPE) + +#define RNN_FWD_SET_(FUNC, NAME, DIMS, TAG, HANDLE, DTYPE) \ +FUNC(MKLDNN_ARG_##NAME, {DIMS, get_mkldnn_type(DTYPE), TAG}, HANDLE) + +#define RNN_BWD_SET(NAME, ARGS, HANDLE) \ +RNN_BWD_SET_(RNN_HANDLE_FUNC_NAME, NAME, ARGS, HANDLE) + +#define RNN_BWD_SET_(FUNC, NAME, ARGS, HANDLE) \ +FUNC(MKLDNN_ARG_DIFF_##NAME, ARGS.at(MKLDNN_ARG_##NAME).get_desc(), HANDLE) + +/* + * Set new src data handler to Forward memory. The memory primitives are + * not initialized until SetNewDataMem is first invoked. Src data handler + * must not be nullptr, except for cx with LSTM mode. If either hy, cy is + * nullptr, it may run with non-state_ouput or non-LSTM mode. Thus, the + * corresponding memory should be a empty mkldnn::memory(). + */ +void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx, + void* y, void* hy, void* cy, + const int dtype) { + using dims = mkldnn::memory::dims; + using desc = mkldnn::memory::desc; + using format_tag = mkldnn::memory::format_tag; + auto& cpu_engine = CpuEngine::Get()->get_engine(); + mkldnn_args_map_t& args = net_args_; + + RNN_HANDLE_FUNC(RNN_HANDLE_FUNC_NAME); + + // Set various data memory + RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, dtype); + RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dtype); + RNN_FWD_SET(SRC_ITER, param_.state_dims, format_tag::ldnc, hx, dtype); + + if (param_.state_outputs) { + RNN_FWD_SET(DST_ITER, param_.state_dims, format_tag::ldnc, hy, dtype); + } + + if (param_.mode == rnn_enum::kLstm) { + RNN_FWD_SET(SRC_ITER_C, param_.state_dims, format_tag::ldnc, cx, dtype); + if (param_.state_outputs) { + RNN_FWD_SET(DST_ITER_C, param_.state_dims, format_tag::ldnc, cy, dtype); + } + } +} + +/* + * Reorder the concatenated weights memory to a efficient memory block + * with primitive-prefered format. + */ +void MKLDNNRnnForward::ReorderWeights() { + auto& cpu_engine = CpuEngine::Get()->get_engine(); + mkldnn::stream s(cpu_engine); + mkldnn::reorder(*weights_layer_r_, *weights_layer_) + .execute(s, *weights_layer_r_, *weights_layer_); + mkldnn::reorder(*weights_iter_r_, *weights_iter_) + .execute(s, *weights_iter_r_, *weights_iter_); + s.wait(); +} + +void AdjustGruGateOrder(char* weight, + const size_t input_size, + const size_t hidden_size, + const int dtype) { + // mxnet gru gate order is reset, update and new gates + // mkldnn gru gate order is update, reset and new gates + size_t single_weight_bytes = input_size * hidden_size * mshadow::mshadow_sizeof(dtype); + char* weight_reset = weight; + char* weight_update = weight + single_weight_bytes; + std::swap_ranges(weight_reset, weight_update, weight_update); +} + +/* + * Fuse uni-directional bias among single layer. + */ +template +void FuseBias(DType* fuse_bias, DType* naive_bias, + const int mode, const size_t state_size) { + const size_t ngates = GetRnnGatesNum(mode); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const size_t nbias = mode == rnn_enum::kGru ? ngates + 1 : ngates; + // MSVC-14.0 (OpenMP 2.0 compatible) doesn't support unsigned integral type in + // OpenMP 'for' statement. + const int state_size_ = static_cast(state_size); + const int single_b_sz = static_cast(nbias * state_size); + DType* bx = naive_bias; + DType* bh = naive_bias + state_size * ngates; + if (mode == rnn_enum::kGru) { + // While mxnet gru gate order is reset, update and new gates, + // mkldnn gru gate order is update, reset and new gates. So + // we need to swap the order of reset and update from mxnet. + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < state_size_; j++) { + // Swap summed reset, update bias + fuse_bias[j + state_size] = bx[j] + bh[j]; + fuse_bias[j] = bx[j + state_size] + bh[j + state_size]; + + // Memcpy two new gates + fuse_bias[j + 2 * state_size] = bx[j + 2 * state_size]; + fuse_bias[j + 3 * state_size] = bh[j + 2 * state_size]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < single_b_sz; ++j) { + // Sum two bias + fuse_bias[j] = bx[j] + bh[j]; + } + } +} + +inline void EmplaceNetArgs(mkldnn_args_map_t* net_args, const int arg_name, + const mkldnn::memory* mem) { + if (net_args->find(arg_name) != net_args->end()) { + if (net_args->at(arg_name).get_data_handle() == mem->get_data_handle()) { + return; + } else { + net_args->at(arg_name).set_data_handle(mem->get_data_handle()); + } + } else { + net_args->emplace(arg_name, *mem); + } +} + +/* + * Copy naive memory to mkldnn-format memory. It will initialize the memory + * when first invoked. Then, the naive weight_layer and weight_iter are + * concatenated to xxx_xx_r memory. Per the different gates order of GRU, + * it will swap the memory blocks of gates among concatenated memory + * inplace. From then on, the xxx_xx_r memory is reordered to target + * memory with preferred format_tag. Finally, naive bias is fused to MKLDNN + * bias memory. + */ +void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_ptr, + const bool is_train, const int dtype) { + using format_tag = mkldnn::memory::format_tag; + auto mkldnn_dtype = get_mkldnn_type(dtype); + // Get the weights' memory for RNN forward primitive + if (weights_layer_ == nullptr) { + weights_layer_ = mgr->Alloc(fwd_inf_.GetLayerDesc()); + } + if (weights_iter_ == nullptr) { + weights_iter_ = mgr->Alloc(fwd_inf_.GetIterDesc()); + } + if (bias_ == nullptr) { + bias_ = mgr->Alloc( + {param_.bias_dims, mkldnn_dtype, format_tag::ldgo}); + } + + // Get the intermediate memory for weights concat & reorder + if (weights_layer_r_ == nullptr) { + weights_layer_r_ = mgr->Alloc( + {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldgoi}); + } + if (weights_iter_r_ == nullptr) { + weights_iter_r_ = mgr->Alloc( + {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldgoi}); + } + + // Get the bytes of a real type + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + + // convert void* to char* for arithmetic operations + char *weights_ptr = static_cast(w_ptr); + size_t wx_bytes = GetRnnGatesNum(param_.mode) * param_.state_size * + param_.input_size * dtype_bytes; //* DIMS: ngates x state_size x input_size + size_t wh_bytes = GetRnnGatesNum(param_.mode) * param_.state_size * + param_.state_size * dtype_bytes; //* DIMS: ngates x state_size x state_size + char *l2r_wx = weights_ptr; + char *l2r_wh = l2r_wx + wx_bytes; //* DIMS: ngates x state_size * state_size + + if (param_.num_layer == 1 && param_.bidirectional) { + //* single bidirectinal layer, concat weights on direction axis + char *r2l_wx = weights_ptr + param_.single_w_size * dtype_bytes; + char *r2l_wh = r2l_wx + wx_bytes; //* DIMS: ngates x state_size * state_size + ConcatWeights(*weights_layer_r_, 1, {l2r_wx, r2l_wx}, format_tag::ldgoi); + ConcatWeights(*weights_iter_r_, 1, {l2r_wh, r2l_wh}, format_tag::ldgoi); + } else if (param_.num_layer == 1 && !param_.bidirectional) { + //* single uni-directional layer, no concatenate operator needed + std::memcpy(weights_layer_r_->get_data_handle(), l2r_wx, wx_bytes); + std::memcpy(weights_iter_r_->get_data_handle(), l2r_wh, wh_bytes); + } else if (param_.num_layer > 1 && !param_.bidirectional) { + //* concat fused multi-layer weights on layer axis + std::vector l2r_wx_ptrs; + std::vector l2r_wh_ptrs; + for (int lyr = 0; lyr < param_.num_layer; ++lyr) { + char *lth_wx = l2r_wx + lyr * param_.single_w_size * dtype_bytes; + char *lth_wh = lth_wx + wx_bytes; + l2r_wx_ptrs.push_back(lth_wx); + l2r_wh_ptrs.push_back(lth_wh); + } + ConcatWeights(*weights_layer_r_, 0, l2r_wx_ptrs, format_tag::ldgoi); + ConcatWeights(*weights_iter_r_, 0, l2r_wh_ptrs, format_tag::ldgoi); + } else { + LOG(FATAL) << "Undifined RNN fusion workflow for num_layer = " << param_.num_layer + << ", and bidirectional is " << param_.bidirectional; + } + + // Adjust gates order of LBR-GRU among concatenated memory inplace. + char* fused_wx = static_cast(weights_layer_r_->get_data_handle()); + char* fused_wh = static_cast(weights_iter_r_->get_data_handle()); + if (param_.mode == rnn_enum::kGru) { + for (size_t lyr = 0; lyr < static_cast(param_.num_layer); ++lyr) { + for (size_t d = 0; d < param_.bidirectional + 1U; ++d) { + AdjustGruGateOrder(fused_wx, param_.input_size, param_.state_size, dtype); + AdjustGruGateOrder(fused_wh, param_.state_size, param_.state_size, dtype); + fused_wx += wx_bytes; + fused_wh += wh_bytes; + } + } + } + // Reorder after adjustment only when is_train == false. When is_train == true, i.e. + // in forward training path, we use plain memory (ldxxx) as the space for weights and + // their gradients. Then, forward training primitives could fetch them from the scope + // of forward inference. And from there, we don't need to reorder the plain memory to + // the optimal rnn-packed memory for forward inference. + ReorderWeights(); + + // Process bias + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* naive_b_ptr = static_cast(b_ptr); + DType* fused_bias = static_cast(bias_->get_data_handle()); + for (int lyr = 0; lyr < param_.num_layer; ++lyr) { + for (int d = 0; d < param_.bidirectional + 1; ++d) { + FuseBias(fused_bias, naive_b_ptr, param_.mode, param_.state_size); + fused_bias += param_.single_b_size; + naive_b_ptr += param_.naive_single_b_size; + } + } + }); + + // insert weights into net_args + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_LAYER, this->weights_layer_); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_ITER, this->weights_iter_); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_BIAS, this->bias_); + + initialized_ = true; +} + +void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) { + using memory = mkldnn::memory; + const auto& cpu_engine = CpuEngine::Get()->get_engine(); + auto s = mkldnn::stream(cpu_engine); + // Prepare mkldnn::memorys for weights_layer, weight_iter, and workspace + if (workspace_ == nullptr) + workspace_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetWorkspaceDesc(), cpu_engine)); + if (weights_layer_ == nullptr) + weights_layer_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetLayerDesc(), cpu_engine)); + if (weights_iter_ == nullptr) + weights_iter_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetIterDesc(), cpu_engine)); + + // fill weights memory using the reordered weights of fwd_inference primitive + if (fwd.weights_layer_r_->get_desc() == fwd_trn_.GetLayerDesc()) { + weights_layer_->set_data_handle(fwd.weights_layer_r_->get_data_handle()); + } else { + mkldnn::reorder(*fwd.weights_layer_r_, *weights_layer_) + .execute(s, *fwd.weights_layer_r_, *weights_layer_); + } + + if (fwd.weights_iter_r_->get_desc() == fwd_trn_.GetIterDesc()) { + weights_iter_->set_data_handle(fwd.weights_iter_r_->get_data_handle()); + } else { + mkldnn::reorder(*fwd.weights_iter_r_, *weights_iter_) + .execute(s, *fwd.weights_iter_r_, *weights_iter_); + } + s.wait(); + + // bias are always in format_tag::ldgo + this->bias_ = fwd.bias_; + + // insert weights into net_args + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_LAYER, this->weights_layer_.get()); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_ITER, this->weights_iter_.get()); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_BIAS, this->bias_); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WORKSPACE, this->workspace_.get()); +} + +void MKLDNNRnnForwardTraining::FetchData(const MKLDNNRnnForward& fwd) { + for (auto& kv : fwd.net_args_) { + switch (kv.first) { + case MKLDNN_ARG_WEIGHTS_LAYER: + case MKLDNN_ARG_WEIGHTS_ITER: + case MKLDNN_ARG_BIAS: + case MKLDNN_ARG_WORKSPACE: + continue; + + default: + EmplaceNetArgs(&this->net_args_, kv.first, &kv.second); + } + } +} + +void MKLDNNRnnOp::Init(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using memory = mkldnn::memory; + using format_tag = mkldnn::memory::format_tag; + + // In the `autograd.record()` context, RNNOp is required to run into + // `forward_training` mode. + const bool is_training = (ctx.is_train || ctx.need_grad); + const size_t num_fusion = full_param_.layer_params.size(); + if (fwd_inf_vec_.size() < num_fusion) { + size_t buffer_size = 0; // Element number, instead of bytes, in the buffer + for (auto& layer_param : full_param_.layer_params) { + buffer_size += layer_param.workspace_size + layer_param.reserve_size; + } + buffer_size += outputs[rnn_enum::kOut].data().Size() * (num_fusion - 1); + buffer_size += kMKLDNNAlign * num_fusion * 5; // Add margin for alignment + + for (auto& layer_param : full_param_.layer_params) { + fwd_inf_vec_.emplace_back(layer_param, + ctx.is_train, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]); + buffer_size += fwd_inf_vec_.back().GetSize(inputs[rnn_enum::kParams].dtype()); + } + mgr_.Init(buffer_size, ctx.run_ctx.ctx, inputs[rnn_enum::kParams].dtype()); + } + + if (is_training && fwd_trn_vec_.size() < num_fusion) { + for (auto& layer_param : full_param_.layer_params) { + fwd_trn_vec_.emplace_back(layer_param, + true, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]); + } + } + + // Get the bytes of a real type + const NDArray &weights = inputs[rnn_enum::kParams]; + int dtype = weights.dtype(); + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + + const RNNParam &default_param = full_param_.default_param; + char *weights_ptr = static_cast(weights.data().dptr_); + char *bias_ptr = weights_ptr + (weights.data().Size() - + GetRnnBiasSize(default_param.num_layers, default_param.state_size, + default_param.bidirectional + 1, default_param.mode)) * dtype_bytes; + for (auto& fwd_layer : fwd_inf_vec_) { + size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; + size_t single_b_bytes = fwd_layer.GetParam().naive_single_b_size * dtype_bytes; + size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1; + size_t layer_weights_bytes = single_w_bytes * directions; + size_t layer_bias_bytes = single_b_bytes * directions; // Naive MXNet has double bias + + if (!fwd_layer.IsInitialized() || is_training) + fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, is_training, dtype); + weights_ptr += layer_weights_bytes; + bias_ptr += layer_bias_bytes; + } + + if (is_training) { + CHECK_EQ(fwd_trn_vec_.size(), fwd_inf_vec_.size()) << + "Layers' configurations of forward inference and forward training are disparate."; + for (size_t lyr = 0; lyr < fwd_inf_vec_.size(); ++lyr) + fwd_trn_vec_.at(lyr).SetTrnMem(fwd_inf_vec_.at(lyr)); + } + + CHECK_EQ(num_fusion, fwd_inf_vec_.size()) << + "Layer vector's size has a different value than the number of fusion."; + if (dst_.size() < num_fusion - 1) { + int data_dtype = outputs[rnn_enum::kOut].dtype(); + // Here we need `fwd_inf_vec_.size() - 1` spaces for the intermediate results of the multiple + // fused layers. And for the result of the last fused layer, `outputs[rnn_enum::kOut]` could + // provide the space. Hence, `forward_inf_vec_.back()` is excluded when allocates the spaces + // for intermediate results. + for (std::vector::const_iterator fwd = fwd_inf_vec_.begin(); + fwd != fwd_inf_vec_.end() - 1; ++fwd) + dst_.push_back(mgr_.Alloc( + {fwd->GetParam().dst_dims, get_mkldnn_type(data_dtype), format_tag::tnc})); + } + + initialized_ = true; +} + +void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd) { + using memory = mkldnn::memory; + auto& cpu_engine = CpuEngine::Get()->get_engine(); + auto s = mkldnn::stream(cpu_engine); + + if (this->weights_layer_ == nullptr) + this->weights_layer_ = mkldnn_shared_mem_t(new memory(bwd_.weights_layer_desc_, cpu_engine)); + if (this->weights_iter_ == nullptr) + this->weights_iter_ = mkldnn_shared_mem_t(new memory(bwd_.weights_iter_desc_, cpu_engine)); + + for (auto& kv : fwd.net_args_) { + const mkldnn::memory* valid_mem; + switch (kv.first) { + case MKLDNN_ARG_WEIGHTS_LAYER: { + if (bwd_.weights_layer_desc_ == fwd.fwd_trn_.GetLayerDesc()) { + this->weights_layer_->set_data_handle(kv.second.get_data_handle()); + } else { + mkldnn::reorder(*fwd.weights_layer_, *this->weights_layer_) + .execute(s, *fwd.weights_layer_, *this->weights_layer_); + } + valid_mem = this->weights_layer_.get(); + } break; + case MKLDNN_ARG_WEIGHTS_ITER: { + if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetLayerDesc()) { + this->weights_iter_->set_data_handle(kv.second.get_data_handle()); + } else { + mkldnn::reorder(*fwd.weights_iter_, *this->weights_iter_) + .execute(s, *fwd.weights_iter_, *this->weights_iter_); + } + valid_mem = this->weights_iter_.get(); + } break; + + default: + valid_mem = &kv.second; + } + EmplaceNetArgs(&this->net_args_, kv.first, valid_mem); + } + s.wait(); +} + +void MKLDNNRnnBackward::SetWeightsGradsMem() { + auto& cpu_engine = CpuEngine::Get()->get_engine(); + if (this->diff_weights_layer_ == nullptr) + this->diff_weights_layer_ = std::make_shared( + bwd_.diff_weights_layer_desc_, cpu_engine); + if (this->diff_weights_iter_ == nullptr) + this->diff_weights_iter_ = std::make_shared( + bwd_.diff_weights_iter_desc_, cpu_engine); + if (this->diff_bias_ == nullptr) + this->diff_bias_ = std::make_shared( + bwd_.diff_bias_desc_, cpu_engine); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_LAYER, + this->diff_weights_layer_.get()); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_ITER, + this->diff_weights_iter_.get()); + EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_BIAS, + this->diff_bias_.get()); +} + +void MKLDNNRnnBackward::SetDataGradsMem( + void* diff_src, void* diff_state, void* diff_statecell, + void* diff_dst, void* diff_state_out, void* diff_statecell_out, + const int dtype) { + using desc = mkldnn::memory::desc; + auto& cpu_engine = CpuEngine::Get()->get_engine(); + mkldnn_args_map_t& args = this->net_args_; + + RNN_HANDLE_FUNC(RNN_HANDLE_FUNC_NAME); + + // Set various diff memory + auto& fwd_args = fwd_ptr_->GetArgsMap(); + RNN_BWD_SET(SRC, fwd_args, diff_src); + RNN_BWD_SET(SRC_ITER, fwd_args, diff_state); + RNN_BWD_SET(DST, fwd_args, diff_dst); + + if (fwd_ptr_->GetParam().state_outputs) + RNN_BWD_SET(DST_ITER, fwd_args, diff_state_out); + + if (fwd_ptr_->GetParam().mode == rnn_enum::kLstm) { + RNN_BWD_SET(SRC_ITER_C, fwd_args, diff_statecell); + if (fwd_ptr_->GetParam().state_outputs) { + RNN_BWD_SET(DST_ITER_C, fwd_args, diff_statecell_out); + } + } +} + +template +void HalveWeightsDiff(DType* w, const size_t size) { + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < static_cast(size); ++i) { + w[i] *= 0.5; + } +} + +void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, int dtype) { + using tag = mkldnn::memory::format_tag; + auto& cpu_engine = CpuEngine::Get()->get_engine(); + auto s = mkldnn::stream(cpu_engine); + + const MKLDNNRnnLayerParam& param = fwd_ptr_->GetParam(); + const int num_layer = param.num_layer; + const int direction = param.bidirectional ? 2 : 1; + const int ngates = GetRnnGatesNum(param.mode); + const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const size_t wxh_bytes = param.single_w_size * dtype_bytes; + const size_t wx_bytes = param.input_size * param.state_size * ngates * dtype_bytes; + const size_t wh_bytes = param.state_size * param.state_size * ngates * dtype_bytes; + char* diff_wx_ptr = static_cast(diff_weights_layer_->get_data_handle()); + char* diff_wh_ptr = static_cast(diff_weights_iter_->get_data_handle()); + + /* naive weights layout is: + 1st-layer: | wx_lr | wh_lr | wx_rl | wh_rl | + 2st-layer: | wx_lr | wh_lr | wx_rl | wh_rl | + size: | wxh_bytes | + |wx_bytes|wh_bytes| + */ + char* naive_weights = static_cast(diff_weights); + if (param.mode != rnn_enum::kGru) { + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wx_ptr + shift * wx_bytes, wx_bytes); + } + // align naive_weights to weights_iter memory + naive_weights += wx_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wh_ptr + shift * wh_bytes, wh_bytes); + } + } else { + const size_t wx_bytes_per_gate = param.input_size * param.state_size * dtype_bytes; + const size_t wh_bytes_per_gate = param.state_size * param.state_size * dtype_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes + wx_bytes_per_gate, + diff_wx_ptr + shift * wx_bytes, wx_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wx_ptr + shift * wx_bytes + wx_bytes_per_gate, wx_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes + 2 * wx_bytes_per_gate, + diff_wx_ptr + shift * wx_bytes + 2 * wx_bytes_per_gate, wx_bytes_per_gate); + } + // align naive_weights to weights_iter memory + naive_weights += wx_bytes; + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_weights + shift * wxh_bytes + wh_bytes_per_gate, + diff_wh_ptr + shift * wh_bytes, wh_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes, + diff_wh_ptr + shift * wh_bytes + wh_bytes_per_gate, wh_bytes_per_gate); + std::memcpy(naive_weights + shift * wxh_bytes + 2 * wh_bytes_per_gate, + diff_wh_ptr + shift * wh_bytes + 2 * wh_bytes_per_gate, wh_bytes_per_gate); + } + } + + char* naive_bias = static_cast(diff_bias); + char* diff_bias_ptr = static_cast(this->diff_bias_->get_data_handle()); + const size_t bias_bytes = param.single_b_size * dtype_bytes; + const size_t naive_bias_bytes = param.naive_single_b_size * dtype_bytes; + if (param.mode != rnn_enum::kGru) { + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + DType* typed_bias = reinterpret_cast(diff_bias_ptr); + HalveWeightsDiff(typed_bias, num_layer * direction * param.single_b_size); + }); + for (int shift = 0; shift < num_layer * direction; ++shift) { + std::memcpy(naive_bias + shift * naive_bias_bytes, + diff_bias_ptr + shift * bias_bytes, bias_bytes); + std::memcpy(naive_bias + shift * naive_bias_bytes + bias_bytes, + diff_bias_ptr + shift * bias_bytes, bias_bytes); + } + } else { + const size_t bias_bytes_per_gate = param.state_size * dtype_bytes; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + for (int shift = 0; shift < num_layer * direction; ++shift) { + char* naive_reset = naive_bias + shift * naive_bias_bytes; + char* naive_update = naive_reset + bias_bytes_per_gate; + char* update = diff_bias_ptr + shift * bias_bytes; + char* reset = update + bias_bytes_per_gate; + + DType* typed_update = reinterpret_cast(update); + HalveWeightsDiff(typed_update, param.state_size * 2); + + std::memcpy(naive_update, update, bias_bytes_per_gate); + std::memcpy(naive_reset, reset, bias_bytes_per_gate); + std::memcpy(naive_update + naive_bias_bytes / 2, update, bias_bytes_per_gate); + std::memcpy(naive_reset + naive_bias_bytes / 2, reset, bias_bytes_per_gate); + + char* naive_new_bx = naive_update + bias_bytes_per_gate; + char* naive_new_bh = naive_new_bx + naive_bias_bytes / 2; + char* new_bx = reset + bias_bytes_per_gate; + char* new_bh = new_bx + bias_bytes_per_gate; + std::memcpy(naive_new_bx, new_bx, bias_bytes_per_gate); + std::memcpy(naive_new_bh, new_bh, bias_bytes_per_gate); + } + }); + } +} + +template +inline void RegisterMKLDNNRnn(MKLDNNRnnX const& rnn) { + MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetFwd(), rnn.GetArgsMap()); +} + +template <> +inline void RegisterMKLDNNRnn(MKLDNNRnnBackward const& rnn) { + MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetBwd(), rnn.GetArgsMap()); +} + +void MKLDNNRnnOp::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + // In the `autograd.record()` context, RNNOp is required to run into + // forward_training mode. + const bool is_training = (ctx.is_train || ctx.need_grad); + // check output requests + if (kAddTo == req[rnn_enum::kOut]) + LOG(FATAL) << "Currently, `add` operation is not supported by RNNs."; + const RNNParam& default_param = full_param_.default_param; + if (default_param.state_outputs) { + if (kAddTo == req[rnn_enum::kStateOut]) + LOG(FATAL) << "Currently, `add` operation is not supported by RNNs."; + if (default_param.mode == rnn_enum::kLstm && kAddTo == req[rnn_enum::kStateCellOut]) + LOG(FATAL) << "Currently, `add` operation against lstm-cell output is not supported."; + } + + // Initialize weights version + if (!initialized_ && weights_version_ == 0) { + weights_version_ = inputs[rnn_enum::kParams].version(); + } + + // Check if weights NDArray was changed. If so, reset initialized_ + if (weights_version_ != inputs[rnn_enum::kParams].version() && + fwd_inf_vec_.size() > 0) { + initialized_ = false; + for (auto& fwd : fwd_inf_vec_) fwd.Reset(); + weights_version_ = inputs[rnn_enum::kParams].version(); + } + + if (!initialized_ || is_training || fwd_inf_vec_.size() == 0) { + Init(ctx, inputs, req, outputs); + } + + // Get data type + int data_dtype = inputs[rnn_enum::kData].dtype(); + + // Get input & output NDArray + char *src = static_cast(inputs[rnn_enum::kData].data().dptr_); + char *src_state = static_cast(inputs[rnn_enum::kState].data().dptr_); + char *dst = req[rnn_enum::kOut] == kNullOp ? nullptr + : static_cast(outputs[rnn_enum::kOut].data().dptr_); + char *dst_state = nullptr; // Output state + char *src_state_cell = nullptr; // Used in LSTM for cell state + char *dst_state_cell = nullptr; // Used in LSTM for cell state + + if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) { + dst_state = static_cast(outputs[rnn_enum::kStateOut].data().dptr_); + } + + if (default_param.mode == rnn_enum::kLstm) { + src_state_cell = static_cast(inputs[rnn_enum::kStateCell].data().dptr_); + if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) { + dst_state_cell = static_cast(outputs[rnn_enum::kStateCellOut].data().dptr_); + } + } + + if (fwd_inf_vec_.size() == 1) { + fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell, + dst, dst_state, dst_state_cell, data_dtype); + if (is_training) { + fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front()); + } + } else { + CHECK_EQ(fwd_inf_vec_.size(), dst_.size() + 1) << "Output memory error."; + size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + + // Set input data memory for the first layer. This stores intermediate output + // results in this->xxx, used as the source input of the next layer. + fwd_inf_vec_.front().SetNewDataMem(src, src_state, src_state_cell, + this->dst_.front()->get_data_handle(), dst_state, dst_state_cell, data_dtype); + if (is_training) { + fwd_trn_vec_.front().FetchData(fwd_inf_vec_.front()); + } + // 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ... + for (size_t lyr = 1; lyr < fwd_inf_vec_.size() - 1; ++lyr) { + src_state += cell_bytes; + if (src_state_cell) src_state_cell += cell_bytes; + if (dst_state) dst_state += cell_bytes; + if (dst_state_cell) dst_state_cell += cell_bytes; + fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(), + src_state, src_state_cell, + this->dst_.at(lyr)->get_data_handle(), dst_state, dst_state_cell, data_dtype); + if (is_training) { + fwd_trn_vec_.at(lyr).FetchData(fwd_inf_vec_.at(lyr)); + } + } + // Set output data memory for the last layer. + src_state += cell_bytes; + if (src_state_cell) src_state_cell += cell_bytes; + if (dst_state) dst_state += cell_bytes; + if (dst_state_cell) dst_state_cell += cell_bytes; + fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(), + src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype); + if (is_training) { + fwd_trn_vec_.back().FetchData(fwd_inf_vec_.back()); + } + } + if (is_training) { + for (auto& trn_lyr : fwd_trn_vec_) RegisterMKLDNNRnn(trn_lyr); + } else { + for (auto& inf_lyr : fwd_inf_vec_) RegisterMKLDNNRnn(inf_lyr); + } + MKLDNNStream::Get()->Submit(); +} + +void MKLDNNRnnOp::Backward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using tag = mkldnn::memory::format_tag; + const RNNParam& default_param = full_param_.default_param; + if (kAddTo == req[rnn_enum::kData] || kAddTo == req[rnn_enum::kParams]) + LOG(FATAL) << "Currently, `add` operations against gradients of input and weights" + << " are not supported by RNNs."; + if (default_param.state_outputs) { + if (kAddTo == req[rnn_enum::kStateOut]) + LOG(FATAL) << "Currently, `add` operation against gradients of begining state" + << " is not supported by RNNs."; + if (default_param.mode == rnn_enum::kLstm && req[rnn_enum::kStateCell]) + LOG(FATAL) << "Currently, `add` operation against gradients of begining cell-state" + << " is not supported by LSTM."; + } + // Initialize the bwd_vec_ + if (bwd_vec_.size() != fwd_inf_vec_.size()) { + bwd_vec_.clear(); + for (size_t lyr = 0; lyr < fwd_inf_vec_.size(); ++lyr) + bwd_vec_.emplace_back(fwd_trn_vec_.at(lyr), inputs[rnn_enum::kData], + inputs[rnn_enum::kParams]); + } + // Fetch weights, src and dst from Forward layer + if (bwd_vec_.size() != fwd_trn_vec_.size()) + LOG(FATAL) << "MKL-DNN RNN fusion error."; + for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) { + bwd_vec_.at(lyr).FetchDataWeightsMem(fwd_trn_vec_.at(lyr)); + bwd_vec_.at(lyr).SetWeightsGradsMem(); + } + + const int data_dtype = inputs[rnn_enum::kData].dtype(); + const int w_dtype = inputs[rnn_enum::kParams].dtype(); + const size_t w_bytes = mshadow::mshadow_sizeof(w_dtype); + // index description of outputs NDArray + // 0 1 2 3 + // | dx | dw | dhx | dcx| + char* dx = req[rnn_enum::kData] == kNullOp ? nullptr + : static_cast(outputs[rnn_enum::kData].data().dptr_); + char* dw = static_cast(outputs[rnn_enum::kParams].data().dptr_); + char* db = dw + (inputs[rnn_enum::kParams].data().Size() - + GetRnnBiasSize(default_param.num_layers, default_param.state_size, + default_param.bidirectional + 1, default_param.mode)) * w_bytes; + char* dhx = req[rnn_enum::kState] == kNullOp ? nullptr + : static_cast(outputs[rnn_enum::kState].data().dptr_); + char* dcx = nullptr; + if (full_param_.default_param.mode == rnn_enum::kLstm + && req[rnn_enum::kStateCell] != kNullOp) + dcx = static_cast(outputs[rnn_enum::kStateCell].data().dptr_); + + // index description of inputs NDArray + // 0 1 2 3 4 5 6 7 8 9 + // | x | w | hx | y | dy | hy | dhy | cx | cy | dcy | + char* dy = static_cast(inputs[4].data().dptr_); + char* dhy = nullptr; + if (default_param.state_outputs) + dhy = static_cast(inputs[6].data().dptr_); + + char* dcy = nullptr; + if ((default_param.mode == rnn_enum::kLstm) && default_param.state_outputs) + dcy = static_cast(inputs[9].data().dptr_); + + if (bwd_vec_.size() == 1) { + bwd_vec_.back().SetDataGradsMem(dx, dhx, dcx, dy, dhy, dcy, data_dtype); + RegisterMKLDNNRnn(bwd_vec_.back()); + } else { + const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + if (diff_src == nullptr) { + auto desc = mkldnn::memory::desc(full_param_.layer_params.back().src_dims, + get_mkldnn_type(data_dtype), tag::tnc); + diff_src = std::make_shared(desc, CpuEngine::Get()->get_engine()); + } + // Sets primitives from bottom to top, then submits them in reversed order. + bwd_vec_.front().SetDataGradsMem(dx, dhx, dcx, + diff_src->get_data_handle(), dhy, dcy, data_dtype); + for (size_t lyr = 1; lyr < bwd_vec_.size() - 1; ++lyr) { + if (dhx) dhx += cell_bytes; + if (dcx) dcx += cell_bytes; + if (dhy) dhy += cell_bytes; + if (dcy) dcy += cell_bytes; + bwd_vec_.at(lyr).SetDataGradsMem(diff_src->get_data_handle(), dhx, dcx, + diff_src->get_data_handle(), dhy, dcy, data_dtype); + } + if (dhx) dhx += cell_bytes; + if (dcx) dcx += cell_bytes; + if (dhy) dhy += cell_bytes; + if (dcy) dcy += cell_bytes; + bwd_vec_.back().SetDataGradsMem(diff_src->get_data_handle(), dhx, dcx, + dy, dhy, dcy, data_dtype); + + for (std::vector::const_reverse_iterator bwd = bwd_vec_.rbegin(); + bwd != bwd_vec_.rend(); ++bwd) { + RegisterMKLDNNRnn(*bwd); + } + } + MKLDNNStream::Get()->Submit(); + + // Commit weights diff + if (req[rnn_enum::kParams] != kNullOp) { + for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) { + bwd_vec_.at(lyr).CommitWeightsDiff(dw, db, w_dtype); + dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes; + db += full_param_.layer_params.at(lyr).single_b_size * w_bytes; + } + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h deleted file mode 100644 index ea8e07ea617c..000000000000 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ /dev/null @@ -1,740 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ -#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ -#if MXNET_USE_MKLDNN == 1 -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "../../math_functions-inl.h" -#include "../../operator_common.h" -#include "../../rnn_impl.h" -#include "../../rnn-inl.h" -#include "mkldnn.hpp" -#include "./mkldnn_base-inl.h" - -namespace mxnet { -namespace op { - -static algorithm GetMKLDNNRNNAlgo(int mode, - int* ngates, - int* nstates) { - algorithm algo = algorithm::vanilla_rnn; - switch (mode) { - case rnn_enum::kLstm: - *ngates = 4; - *nstates = 2; - algo = algorithm::vanilla_lstm; - break; - case rnn_enum::kGru: - *ngates = 3; - *nstates = 1; - algo = algorithm::vanilla_gru; - break; - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - *ngates = 1; - *nstates = 1; - algo = algorithm::vanilla_rnn; - break; - default: - LOG(FATAL) << "unsupported RNN mode:" << mode; - break; - } - return algo; -} - -static void ConcatData(mkldnn::memory::format src_format, - mkldnn::memory::format dst_format, - std::vector srcs_cds, - mkldnn::memory::dims dst_cds, - mkldnn::memory::data_type mkldnn_dtype, - int concat_dimension, - std::vector srcs_data, - const mkldnn::memory &dst) { - auto cpu_engine = CpuEngine::Get()->get_engine(); - std::vector srcs_pd; - std::vector srcs; - for (size_t i = 0; i < srcs_cds.size(); i++) { - auto desc = mkldnn::memory::desc(srcs_cds[i], mkldnn_dtype, src_format); - auto mpd = mkldnn::memory::primitive_desc(desc, cpu_engine); - auto src_memory = mkldnn::memory(mpd, srcs_data[i]); - srcs_pd.push_back(mpd); - srcs.push_back(src_memory); - } - std::vector inputs; - for (size_t i = 0; i < srcs_cds.size(); i++) { - inputs.push_back(srcs[i]); - } - auto dst_desc = mkldnn::memory::desc(dst_cds, mkldnn_dtype, dst_format); - auto concat_pd = concat::primitive_desc(dst_desc, concat_dimension, srcs_pd); - MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst)); - MKLDNNStream::Get()->Submit(); -} - -// cached mkldnn memory -// first layer wx, wh with next L - 1 layers wx and wh -// with L layers hx and cx, src and dst data/iter etc. -// it will prepare memory on before and after reorder and concat. -// for unidirectional, it will fused as dim like 1 + (L - 1) when I != H. -// for bidirectional, it will fused as data + back_data (weight, bias, iter etc), -// also need to identify first layer and next layers -static size_t GetMKLDNNRNNCacheMemorySize(int L, - int D, - int T, - int N, - int I, - int H, - int mode) { - size_t size = 0; - switch (mode) { - case rnn_enum::kLstm: - size = 2 * (D * (I + H) * 4 * H + (L - 1) * D * (D * H + H) * 4 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 4 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 4 * H + T * N * I * 2; - break; - case rnn_enum::kGru: - size = 2 * (D * (I + H) * 3 * H + (L - 1) * D * (D * H + H) * 3 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 3 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 3 * H + T * N * I * 2; - break; - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - size = 2 * (D * (I + H) * 1 * H + (L - 1) * D * (D * H + H) * 1 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 1 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 1 * H + T * N * I * 2; - break; - default: - LOG(FATAL) << "unknown RNN mode " << mode; - break; - } - return size; -} - -template -static void AdjustGruWeightGateOrder(DType* weight, - const int I, - const int H) { - // mxnet gru gate order is reset, update and new gates - // mkldnn gru gate order is update, reset and new gates - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - DType* weight_reset = weight; - DType* weight_update = weight + I * H; - #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < I * H; i++) { - DType tmp = weight_update[i]; - weight_update[i] = weight_reset[i]; - weight_reset[i] = tmp; - } -} - -template -static void AdjustGruBiasGateOrder(DType* bias, - const int H) { - // mxnet gru gate order is reset, update and new gates - // mkldnn gru gate order is update, reset and new gates - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - DType* bias_reset = bias; - DType* bias_update = bias + H; - #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < H; i++) { - DType tmp = bias_update[i]; - bias_update[i] = bias_reset[i]; - bias_reset[i] = tmp; - } -} -// since there is different sematics of MKLDNN's Fused RNN and MXNet FusedRNN, -// bidirectional will be fused layer by layer, -// unidirectional will be done by fused 1 + fused (L - 1) layers or fused L layers(when I = H) - -template -static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - mkldnn::memory *user_src_layer_memory, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - int layer_index, - bool *has_cache, - int lvalue, - int dtype, - bool is_train, - int mode) { - int ngates = 0, nstates = 0; - algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); - mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - const int single_cell_size = N * H; - const int single_b_size = ngates * H; - DType* wx = w_ptr; // ngates * H, I - DType* wh = w_ptr + I * H * ngates; // ngates * H, H - DType* back_wx = w_ptr + ngates * H * (I + H); - DType* back_wh = back_wx + I * H * ngates; - DType* bx = b_ptr; - DType* bh = b_ptr + H * ngates; - DType* back_bx = b_ptr + single_b_size * 2; - DType* back_bh = back_bx + H * ngates; - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - auto cpu_engine = CpuEngine::Get()->get_engine(); - auto null_memory_ = null_memory(cpu_engine); - int offset1 = 0, offset2 = 0; - bool initialized = *has_cache; - mkldnn::memory::dims src_layer_tz = {T, N, I}; - mkldnn::memory::dims dst_layer_tz = {T, N, 2 * H}; - mkldnn::memory::dims weights_layer_tz = {1, 2, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder - mkldnn::memory::dims weights_iter_tz = {1, 2, H, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder - mkldnn::memory::dims bias_tz = {1, 2, ngates, H}; - mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc - mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc - - if (!initialized) { - if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(back_wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruWeightGateOrder(back_wh, H, H); - AdjustGruBiasGateOrder(bx, H); - AdjustGruBiasGateOrder(back_bx, H); - AdjustGruBiasGateOrder(bh, H); - AdjustGruBiasGateOrder(back_bh, H); - } - auto src_wx = (*concat_weight_memory)[2 * layer_index]; - auto src_wh = (*concat_weight_memory)[2 * layer_index + 1]; - std::vector srcs_data1; - srcs_data1.push_back(wx); - srcs_data1.push_back(back_wx); - ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - {weights_layer_r_tz, weights_layer_r_tz}, weights_layer_tz, - mkldnn_dtype, 1, srcs_data1, src_wx); - srcs_data1.clear(); - srcs_data1.push_back(wh); - srcs_data1.push_back(back_wh); - ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - {weights_iter_r_tz, weights_iter_r_tz}, weights_iter_tz, - mkldnn_dtype, 1, srcs_data1, src_wh); - int tmpvalue = 0; - if (lvalue > 0) { - tmpvalue = lvalue + 1; - } - MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, (*wx_memory)[tmpvalue])); - MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, (*wh_memory)[tmpvalue])); - - DType* user_bias = reinterpret_cast - ((*bias_memory)[tmpvalue].get_data_handle()); - #pragma omp parallel for num_threads(omp_threads) - for (int j = 0; j < single_b_size; j++) { - user_bias[j] = bx[j] + bh[j]; - user_bias[single_b_size + j] = back_bx[j] + back_bh[j]; - } - } - if (lvalue > 0) { - (*wx_memory)[layer_index].set_data_handle((*wx_memory)[lvalue + 1].get_data_handle()); - (*wh_memory)[layer_index].set_data_handle((*wh_memory)[lvalue + 1].get_data_handle()); - (*bias_memory)[layer_index].set_data_handle((*bias_memory)[lvalue + 1].get_data_handle()); - } - - auto src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto dst_layer_md = mkldnn::memory::desc( - { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto dst_iter_md = mkldnn::memory::desc( - { dst_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - auto src_iter_md = mkldnn::memory::desc( - {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); - auto bias_md = mkldnn::memory::desc({bias_tz}, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - auto user_src_iter_memory = (*concat_iter_memory)[2]; - if (mode == rnn_enum::kLstm) { - std::vector srcs_data1; - srcs_data1.push_back(hx_ptr); - srcs_data1.push_back(cx_ptr); - auto tmp1_src_iter_memory = (*concat_iter_memory)[0]; - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, - srcs_data1, tmp1_src_iter_memory); - std::vector srcs_data2; - srcs_data2.push_back(hx_ptr + single_cell_size); - srcs_data2.push_back(cx_ptr + single_cell_size); - auto tmp2_src_iter_memory = (*concat_iter_memory)[1]; - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, - srcs_data2, tmp2_src_iter_memory); - std::vector srcs_data3; - srcs_data3.push_back(reinterpret_cast(tmp1_src_iter_memory.get_data_handle())); - srcs_data3.push_back(reinterpret_cast(tmp2_src_iter_memory.get_data_handle())); - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, nstates, N, H}, {1, 1, nstates, N, H}}, {1, 2, nstates, N, H}, - mkldnn_dtype, 1, srcs_data3, user_src_iter_memory); - } else { - user_src_iter_memory.set_data_handle(hx_ptr); - } - (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); - - rnn_cell::desc rnn_cell(nalgorithm, - mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh); - - rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell, - rnn_direction::bidirectional_concat, src_layer_md, - src_iter_md, weight_layer_md, weight_iter_md, - bias_md, dst_layer_md, dst_iter_md); - - auto prim_desc - = rnn_forward::primitive_desc(layer_desc, cpu_engine); - - if (x_ptr && layer_index == 0) { - (*x_memory)[layer_index].set_data_handle(x_ptr); - } else { - (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); - } - (*y_memory)[layer_index].set_data_handle(y_ptr); - - if (rnn_forward_prim->size() <= (size_t)layer_index) { - primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], - (*hcx_memory)[layer_index], (*wx_memory)[layer_index], - (*wh_memory)[layer_index], (*bias_memory)[layer_index], - (*y_memory)[layer_index], - (*hcy_memory)[layer_index], null_memory_); - rnn_forward_prim->push_back(rnn_prim); - } - MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); - MKLDNNStream::Get()->Submit(); - - if (state_outputs) { - DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); - if (mode == rnn_enum::kLstm) { - offset1 = nstates * single_cell_size; - offset2 = (nstates + 1) * single_cell_size; - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < single_cell_size; n++) { - hy_ptr[n] = dst_hcy[n]; - hy_ptr[n + single_cell_size] = dst_hcy[n + offset1]; - cy_ptr[n] = dst_hcy[n + single_cell_size]; - cy_ptr[n + single_cell_size] = dst_hcy[n + offset2]; - } - } else { - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < 2 * single_cell_size; n++) { - hy_ptr[n] = dst_hcy[n]; - } - } - } -} - - -template -static void MKLDNNRNNForwardUnidi(bool state_outputs, - const int L, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - mkldnn::memory *user_src_layer_memory, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - int layer_index, - bool *has_cache, - int dtype, - bool is_train, - int mode) { - int ngates = 0, nstates = 0; - algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); - mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - const int cell_size = N * H; - const int single_cell_size = N * H; - const int single_b_size = ngates * H; - int w_size = (I + H) * H * ngates; - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - auto cpu_engine = CpuEngine::Get()->get_engine(); - auto null_memory_ = null_memory(cpu_engine); - int offset1 = 0, offset2 = 0; - bool initialized = *has_cache; - - mkldnn::memory::dims src_layer_tz = {T, N, I}; - mkldnn::memory::dims dst_layer_tz = {T, N, H}; - mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; - mkldnn::memory::dims src_iter_tz = {L, 1, nstates, N, H}; // ldsnc - mkldnn::memory::dims dst_iter_tz = {L, 1, nstates, N, H}; // ldsnc - mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder - mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder - - auto weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto dst_layer_md = mkldnn::memory::desc( - {dst_layer_tz}, mkldnn_dtype, mkldnn::memory::format::tnc); - auto src_iter_md = mkldnn::memory::desc( - {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); - auto bias_md = mkldnn::memory::desc({bias_tz}, - mkldnn_dtype, mkldnn::memory::format::ldgo); - auto dst_iter_md = mkldnn::memory::desc( - {dst_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); - - for (int l = 0; l < L; l++) { - if (mode == rnn_enum::kLstm) { - std::vector srcs_data; - srcs_data.push_back(hx_ptr); - srcs_data.push_back(cx_ptr); - auto tmp_src_iter_memory = (*concat_iter_memory)[l + layer_index]; - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, - 2, srcs_data, tmp_src_iter_memory); - } else { - (*concat_iter_memory)[l + layer_index].set_data_handle(hx_ptr); - } - hx_ptr += cell_size; - if (mode == rnn_enum::kLstm) { - cx_ptr += cell_size; - } - } - - auto user_src_iter_memory = null_memory_; - if (L == 1) { - user_src_iter_memory = (*concat_iter_memory)[layer_index]; - } else { - user_src_iter_memory = (*concat_iter_memory)[L + layer_index]; - std::vector src_l_data; - std::vector src_l_dim; - for (int l = 0; l < L; l++) { - src_l_data.push_back(reinterpret_cast - ((*concat_iter_memory)[l + layer_index].get_data_handle())); - src_l_dim.push_back({1, 1, nstates, N, H}); - } - ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, src_l_dim, - {L, 1, nstates, N, H}, mkldnn_dtype, 0, src_l_data, user_src_iter_memory); - } - (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); - - auto src_wx_f = (*concat_weight_memory)[2 * layer_index]; - auto src_wh_f = (*concat_weight_memory)[2 * layer_index + 1]; - - std::vector srcs_data_x; - std::vector srcs_data_h; - std::vector src_l_dim_x; - std::vector src_l_dim_h; - if (!initialized) { - if (L == 1) { - DType* wx = w_ptr; - DType* wh = w_ptr + I * H * ngates; - if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruBiasGateOrder(b_ptr, H); - AdjustGruBiasGateOrder(b_ptr + H * ngates, H); - } - src_wx_f.set_data_handle(wx); - src_wh_f.set_data_handle(wh); - } else { - for (int l = 0; l < L; l++) { - DType* wx = w_ptr; - DType* wh = w_ptr + I * H * ngates; - DType* bx = b_ptr + l * ngates * H * 2; - DType* bh = b_ptr + l * ngates * H * 2 + H * ngates; - if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruBiasGateOrder(bx, H); - AdjustGruBiasGateOrder(bh, H); - } - srcs_data_x.push_back(wx); - srcs_data_h.push_back(wh); - src_l_dim_x.push_back(weights_layer_r_tz); - src_l_dim_h.push_back(weights_iter_r_tz); - w_ptr = w_ptr + w_size; - } - ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f); - ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f); - } - MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, (*wx_memory)[layer_index])); - MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, (*wh_memory)[layer_index])); - - DType* user_bias_f = reinterpret_cast ((*bias_memory)[layer_index].get_data_handle()); - #pragma omp parallel for num_threads(omp_threads) - for (int j = 0; j < L * single_b_size; j++) { - int k = j / single_b_size; - user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size]; - } - } - - rnn_cell::desc rnn_cell(nalgorithm, - mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh); - - rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell, - rnn_direction::unidirectional, src_layer_md, - src_iter_md, weight_layer_md, weight_iter_md, - bias_md, dst_layer_md, dst_iter_md); - - auto prim_desc - = rnn_forward::primitive_desc(layer_desc, cpu_engine); - - if (x_ptr && layer_index == 0) { - (*x_memory)[layer_index].set_data_handle(x_ptr); - } else { - (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); - } - (*y_memory)[layer_index].set_data_handle(y_ptr); - - if (rnn_forward_prim->size() <= (size_t)layer_index) { - primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], - (*hcx_memory)[layer_index], (*wx_memory)[layer_index], - (*wh_memory)[layer_index], (*bias_memory)[layer_index], - (*y_memory)[layer_index], - (*hcy_memory)[layer_index], null_memory_); - rnn_forward_prim->push_back(rnn_prim); - } - MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); - MKLDNNStream::Get()->Submit(); - - if (state_outputs) { - DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); - if (mode == rnn_enum::kLstm) { - for (int l = 0; l < L; l++) { - offset1 = l * single_cell_size; - offset2 = l * nstates * single_cell_size; - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < single_cell_size; n++) { - hy_ptr[offset1 + n] = dst_hcy[offset2 + n]; - cy_ptr[offset1 + n] = dst_hcy[offset2 + n + single_cell_size]; - } - } - } else { - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < L * single_cell_size; n++) { - hy_ptr[n] = dst_hcy[n]; - } - } - } -} - -template -static void MKLDNNRNNForward(bool state_outputs, - const int L, - const int D, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - bool *has_cache, - int dtype, - bool is_train, - int mode) { - int ngates = 0, nstates = 0; - GetMKLDNNRNNAlgo(mode, &ngates, &nstates); - const int b_size = 2 * H * ngates * D; - const int cell_size = N * H * D; - // First layer - int w_size = (I + H) * H * ngates * D; - auto cpu_engine = CpuEngine::Get()->get_engine(); - auto null_memory_ = null_memory(cpu_engine); - DType* tmpNull = NULL; - // when D = 1 and I == H, L layers can be fused together - if (D == 1 && I == H) { - MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, &null_memory_, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - 0, has_cache, dtype, is_train, mode); - } else { - auto user_src_layer_memory_l = null_memory_; - if (D == 2) { - MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, &user_src_layer_memory_l, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - 0, has_cache, 0, dtype, is_train, mode); - } else { - MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, &user_src_layer_memory_l, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - 0, has_cache, dtype, is_train, mode); - } - if (L > 1) { - user_src_layer_memory_l = (*y_memory)[0]; - // go to next L - 1 layers. - // If D = 2, do it layer by layer. If D = 1, fused L - 1 layers - w_ptr += w_size; - b_ptr += b_size; - if (D == 2) { - w_size = (H * D + H) * H * ngates * D; - for (int l = 0; l < L - 1; l++) { - if (state_outputs) { - hy_ptr += cell_size; - if (mode == rnn_enum::kLstm) { - cy_ptr += cell_size; - } - } - hx_ptr += cell_size; - if (mode == rnn_enum::kLstm) { - cx_ptr += cell_size; - } - MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, D * H, H, tmpNull, - &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, - cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, - hcx_memory, wx_memory, wh_memory, bias_memory, - y_memory, hcy_memory, rnn_forward_prim, - 1, has_cache, l + 1, dtype, is_train, mode); - user_src_layer_memory_l = (*y_memory)[1]; - w_ptr += w_size; - b_ptr += b_size; - } - } - if (D == 1) { - if (state_outputs) { - hy_ptr += cell_size; - if (mode == rnn_enum::kLstm) { - cy_ptr += cell_size; - } - } - w_size = (H + H) * H * ngates; - MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, &user_src_layer_memory_l, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, - wh_memory, bias_memory, y_memory, hcy_memory, - rnn_forward_prim, 1, has_cache, dtype, is_train, mode); - } - } - } - *has_cache = true; -} - -template -static void MKLDNNRNNForwardInference(bool state_outputs, - const int num_layers, - const int direction, - const int seq_length, - const int batch_size, - const int input_size, - const int state_size, - DType* x_ptr, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector* concat_weight_memory, - std::vector* concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - bool *has_cache, - int dtype, - bool is_train, - int mode) { - switch (mode) { - case rnn_enum::kLstm: - case rnn_enum::kGru: - case rnn_enum::kRnnTanh: - case rnn_enum::kRnnRelu: - MKLDNNRNNForward(state_outputs, num_layers, direction, seq_length, - batch_size, input_size, state_size, x_ptr, hx_ptr, - cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, - concat_weight_memory, concat_iter_memory, x_memory, - hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - has_cache, dtype, is_train, mode); - break; - default: - LOG(FATAL) << "unknown RNN mode" << mode; - break; - } -} - -} // namespace op -} // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 -#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_slice-inl.h b/src/operator/nn/mkldnn/mkldnn_slice-inl.h index f41db01a9837..e6258c8c3f43 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_slice-inl.h @@ -45,7 +45,7 @@ class MKLDNNSliceFwd { const NDArray &in, const NDArray &out); void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output); - const mkldnn::reorder &GetPd() const; + void Register(); private: std::shared_ptr data_; diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc index 2a817a25a5b8..dba10f8b6cd5 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice.cc +++ b/src/operator/nn/mkldnn/mkldnn_slice.cc @@ -49,13 +49,15 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam ¶m, dims[i] = oshape[i]; offsets[i] = s; } - auto in_mem_pd = in.GetMKLDNNData()->get_primitive_desc(); - auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc(); - auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets); - auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd); - this->data_ = std::make_shared(view_pd.dst_primitive_desc(), nullptr); - this->out_ = std::make_shared(view_pd.dst_primitive_desc(), nullptr); - this->fwd_ = std::make_shared(reorder_pd, *this->data_, *this->out_); + + auto in_md = in.GetMKLDNNData()->get_desc(); + auto out_md = out.GetMKLDNNData()->get_desc(); + auto sub_md = in_md.submemory_desc(dims, offsets); + + auto engine = CpuEngine::Get()->get_engine(); + this->data_ = std::make_shared(sub_md, engine, nullptr); + this->out_ = std::make_shared(out_md, engine, nullptr); + this->fwd_ = std::make_shared(*this->data_, *this->out_); } void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output) { @@ -63,8 +65,9 @@ void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory this->out_->set_data_handle(output.get_data_handle()); } -const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const { - return *fwd_; +void MKLDNNSliceFwd::Register() { + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, + {{MKLDNN_ARG_FROM, *(this->data_)}, {MKLDNN_ARG_TO, *(this->out_)}}); } MKLDNNSliceFwd &GetSliceForward(const SliceParam ¶m, const bool is_train, @@ -91,10 +94,10 @@ void MKLDNNSlice(const SliceParam ¶m, const OpContext& ctx, const NDArray &in, OpReqType req, const NDArray &out) { MKLDNNSliceFwd &fwd = GetSliceForward(param, ctx.is_train, in, out); auto in_mem = in.GetMKLDNNData(); - auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc(); - auto out_mem = CreateMKLDNNMem(out, out_mem_pd, req); + auto out_md = out.GetMKLDNNData()->get_desc(); + auto out_mem = CreateMKLDNNMem(out, out_md, req); fwd.SetNewMem(*in_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetPd()); + fwd.Register(); CommitOutput(out, out_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc index 77ab43b63fd5..5b43cb0b0864 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc @@ -31,13 +31,26 @@ namespace mxnet { namespace op { +static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd( + bool is_train, const int axis, + const mkldnn::memory &input_mem) { + mkldnn::memory::desc data_md = input_mem.get_desc(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto prop = is_train ? mkldnn::prop_kind::forward_training + : mkldnn::prop_kind::forward_scoring; + auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis); + return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine); +} + + bool SupportMKLDNNSoftmax(const SoftmaxParam ¶m, const NDArray &data, const NDArray &output) { + // MKLDNN does not support temperature argument in their softmax function + // now. Need update this once they start to support it. const int ndim = data.shape().ndim(); const int in_dtype = data.dtype(); const int out_dtype = output.dtype(); - const int axis = CheckAxis(param.axis, ndim); // MKLDNN does not support temperature argument in their softmax function // now. Need update this once they start to support it. @@ -48,21 +61,12 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam ¶m, axis != (ndim - 1)) { return false; } + // only supports ndim = 1, 2, 3, 4 for now return (ndim >= 1 && ndim <= 4); } -static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(const int axis, - const bool is_train, - const mkldnn::memory &input) { - auto data_md = input.get_primitive_desc().desc(); - auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; - auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis); - auto pd = mkldnn::softmax_forward::primitive_desc(desc, CpuEngine::Get()->get_engine()); - return pd; -} - -void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs, +void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, @@ -71,21 +75,23 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs, // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now. CHECK_NE(req, kAddTo); const SoftmaxParam& param = nnvm::get(attrs.parsed); - const int axis = CheckAxis(param.axis, in_data.shape().ndim()); - + int axis = CheckAxis(param.axis, in_data.shape().ndim()); NDArray data = in_data; if (in_data.IsView() && in_data.IsMKLDNNData()) { data = in_data.Reorder2Default(); } auto data_mem = data.GetMKLDNNData(); - auto pd = GetSoftmaxFwdPd(axis, ctx.is_train, *data_mem); - auto out_mem = CreateMKLDNNMem(out_data, pd.dst_primitive_desc(), req); + auto pd = GetSoftmaxFwdPd(ctx.is_train, axis, *data_mem); + auto out_mem = CreateMKLDNNMem(out_data, pd.dst_desc(), req); MKLDNNStream *stream = MKLDNNStream::Get(); - stream->RegisterPrim(mkldnn::softmax_forward(pd, *data_mem, *out_mem.second)); + stream->RegisterPrimArgs(pd, + {{MKLDNN_ARG_SRC, *data_mem}, {MKLDNN_ARG_DST, *out_mem.second}}); CommitOutput(out_data, out_mem); stream->Submit(); } + } // namespace op } // namespace mxnet #endif + diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc index ae34fe633d6f..dbd3abf2276d 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc @@ -27,15 +27,13 @@ #include "../../softmax_output-inl.h" #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" - namespace mxnet { namespace op { static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl( const SoftmaxOutputParam& param, bool is_train, const int axis, const mkldnn::memory &input_mem) { - mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); - mkldnn::memory::desc data_md = data_mpd.desc(); + mkldnn::memory::desc data_md = input_mem.get_desc(); auto cpu_engine = CpuEngine::Get()->get_engine(); auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; @@ -47,8 +45,6 @@ typedef ParamOpSign MKLDNNSoftmaxOuputSignature; class MKLDNNSoftmaxOutputFwd { std::shared_ptr fwd_; - std::shared_ptr data_; - std::shared_ptr out_; public: const mkldnn::softmax_forward::primitive_desc fwd_pd; @@ -56,29 +52,10 @@ class MKLDNNSoftmaxOutputFwd { MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param, bool is_train, const int axis, const mkldnn::memory &mem): fwd_pd( GetSoftmaxOutputFwdDescImpl(param, is_train, axis, mem)) { + fwd_ = std::make_shared(fwd_pd); } - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { - if (this->data_ == nullptr) - this->data_ = std::shared_ptr(new mkldnn::memory( - data.get_primitive_desc(), data.get_data_handle())); - else - this->data_->set_data_handle(data.get_data_handle()); - - if (this->out_ == nullptr) - this->out_ = std::shared_ptr(new mkldnn::memory( - output.get_primitive_desc(), output.get_data_handle())); - else - this->out_->set_data_handle(output.get_data_handle()); - - if (this->fwd_ == nullptr) { - this->fwd_ = std::shared_ptr( - new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data_), - *this->out_)); - } - } - - const mkldnn::softmax_forward &GetFwd() const { + const inline mkldnn::softmax_forward &GetFwd() const { return *fwd_; } }; @@ -129,17 +106,17 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, auto input_mem = idata.GetMKLDNNData(); auto out_mem = CreateMKLDNNMem(out_data[softmaxout_enum::kOut], - input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]); + input_mem->get_desc(), req[softmaxout_enum::kOut]); MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata); - fwd.SetNewMem(*input_mem, *out_mem.second); MKLDNNStream *stream = MKLDNNStream::Get(); - stream->RegisterPrim(fwd.GetFwd()); - + stream->RegisterPrimArgs(fwd.GetFwd(), + {{MKLDNN_ARG_SRC, *input_mem}, {MKLDNN_ARG_DST, *out_mem.second}}); CommitOutput(out_data[softmaxout_enum::kOut], out_mem); stream->Submit(); } } // namespace op } // namespace mxnet #endif + diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc index 724b8a2613d6..5027bcbaabb1 100644 --- a/src/operator/nn/mkldnn/mkldnn_sum.cc +++ b/src/operator/nn/mkldnn/mkldnn_sum.cc @@ -28,35 +28,38 @@ #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" -#if MXNET_USE_MKLDNN == 1 namespace mxnet { namespace op { -void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, - const mkldnn::memory &out) { - std::vector input_pds(2); +#if MXNET_USE_MKLDNN == 1 +void MKLDNNSum(const mkldnn::memory &arr1, + const mkldnn::memory &arr2, + const mkldnn::memory &out) { + std::vector input_pds(2); std::vector scales(2, 1); - std::vector inputs; - input_pds[0] = arr1.get_primitive_desc(); - input_pds[1] = arr2.get_primitive_desc(); + input_pds[0] = arr1.get_desc(); + input_pds[1] = arr2.get_desc(); CHECK(input_pds[0] == input_pds[0]); const mkldnn::memory *in_mem1 = &arr1; const mkldnn::memory *in_mem2 = &arr2; - auto output_pd = out.get_primitive_desc(); + auto output_pd = out.get_desc(); if (input_pds[0] != output_pd) { auto tmp_memory1 = TmpMemMgr::Get()->Alloc(output_pd); auto tmp_memory2 = TmpMemMgr::Get()->Alloc(output_pd); mxnet::MKLDNNCopy(arr1, tmp_memory1); mxnet::MKLDNNCopy(arr2, tmp_memory2); - input_pds[0] = tmp_memory1->get_primitive_desc(); - input_pds[1] = tmp_memory2->get_primitive_desc(); + input_pds[0] = tmp_memory1->get_desc(); + input_pds[1] = tmp_memory2->get_desc(); in_mem1 = tmp_memory1; in_mem2 = tmp_memory2; } - inputs.push_back(*in_mem1); - inputs.push_back(*in_mem2); - mkldnn::sum::primitive_desc sum_pd(scales, input_pds); - MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out)); + mkldnn::sum::primitive_desc sum_pd(output_pd, scales, input_pds, CpuEngine::Get()->get_engine()); + mkldnn_args_map_t args = { + { MKLDNN_ARG_MULTIPLE_SRC, *in_mem1 }, + { MKLDNN_ARG_MULTIPLE_SRC + 1, *in_mem2 }, + { MKLDNN_ARG_DST, out }, + }; + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::sum(sum_pd), args); } class MKLDNNSumFwd { @@ -64,25 +67,20 @@ class MKLDNNSumFwd { mkldnn::sum::primitive_desc fwd_pd; MKLDNNSumFwd(const std::vector &scales, - const std::vector &data_md) - : fwd_pd(scales, data_md) { - data_.resize(data_md.size()); + const std::vector &data_md) + : fwd_pd(scales, data_md, CpuEngine::Get()->get_engine()) { + fwd_ = std::make_shared(fwd_pd); } - void SetNewMem(const std::vector &in_data, const mkldnn::memory &output); - const mkldnn::sum &GetFwd() const { return *fwd_; } private: std::shared_ptr fwd_; - std::vector> data_; - std::vector data_mem_; - std::shared_ptr out_; }; static MKLDNNSumFwd &GetSumForward( const std::vector &scales, const std::vector &in_data, - const std::vector &data_md) { + const std::vector &data_md) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else @@ -99,35 +97,12 @@ static MKLDNNSumFwd &GetSumForward( return it->second; } -void MKLDNNSumFwd::SetNewMem(const std::vector &in_data, - const mkldnn::memory &output) { - auto num_inputs = data_.size(); - CHECK_EQ(in_data.size(), num_inputs); - for (index_t i = 0; i < static_cast(num_inputs); ++i) { - if (this->data_[i] == nullptr) { - this->data_[i] = std::shared_ptr( - new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle())); - this->data_mem_.push_back(*this->data_[i]); - } else { - this->data_[i]->set_data_handle(in_data[i]->get_data_handle()); - } - } - if (this->out_ == nullptr) - this->out_ = std::shared_ptr( - new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle())); - else - this->out_->set_data_handle(output.get_data_handle()); - - if (this->fwd_ == nullptr) - this->fwd_.reset(new mkldnn::sum(fwd_pd, this->data_mem_, *this->out_)); -} - void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const OpReqType &req, const NDArray &out_data) { TmpMemMgr::Get()->Init(ctx.requested[0]); - auto num_inputs = inputs.size(); - std::vector data_md; + const int num_inputs = inputs.size(); + std::vector data_md; std::vector data_mem; std::vector scales(num_inputs, 1); std::vector in_bufs(num_inputs); @@ -135,7 +110,7 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, data_md.reserve(num_inputs); data_mem.reserve(num_inputs); - for (index_t i = 0; i < static_cast(num_inputs); ++i) { + for (int i = 0; i < num_inputs; ++i) { const mkldnn::memory *in_mem; if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) { in_bufs[i] = inputs[i].Reorder2Default(); @@ -144,22 +119,26 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, in_bufs[i] = inputs[i]; in_mem = inputs[i].GetMKLDNNData(); } - mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc(); - data_md.push_back(tmp_pd); + mkldnn::memory::desc tmp_md = in_mem->get_desc(); + data_md.push_back(tmp_md); data_mem.push_back(in_mem); } MKLDNNSumFwd &fwd = GetSumForward(scales, in_bufs, data_md); mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data, - fwd.fwd_pd.dst_primitive_desc(), + fwd.fwd_pd.dst_desc(), req, &in_bufs[0]); - fwd.SetNewMem(data_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + mkldnn_args_map_t net_args; + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + for (int i = 0; i < num_inputs; ++i) { + net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *data_mem[i]}); + } + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); CommitOutput(out_data, out_mem); MKLDNNStream::Get()->Submit(); } +#endif } // namespace op } // namespace mxnet -#endif diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc index 48444feedcec..2ec38d586552 100644 --- a/src/operator/nn/mkldnn/mkldnn_transpose.cc +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -45,9 +45,10 @@ bool SupportMKLDNNTranspose(const TransposeParam& param, typedef ParamOpSign MKLDNNTransposeSignature; class MKLDNNTransposeForward { + public: std::shared_ptr data_; std::shared_ptr out_; - std::shared_ptr dst_pd_; + std::shared_ptr dst_md_; std::shared_ptr transpose_; public: @@ -67,38 +68,23 @@ class MKLDNNTransposeForward { auto engine = CpuEngine::Get()->get_engine(); auto in_mem = data.GetMKLDNNData(); - auto src_pd = in_mem->get_primitive_desc(); - data_ = std::make_shared(src_pd, nullptr); - - // destination - // Not all formats are well defined with a certain name in MKL-DNN. - // For example, transpose(NCHW, (0, 2, 1, 3)) -> NHCW, which is not explicitly defined in - // MKL-DNN. To support general transposing, we need create destination format from scratch. - mkldnn_memory_desc_t dst_fmt; - dst_fmt.primitive_kind = mkldnn_memory; - dst_fmt.ndims = data_ndim; - dst_fmt.data_type = mkldnn_f32; - dst_fmt.format = mkldnn_blocked; - - for (int i = 0; i < data_ndim; i++) - dst_fmt.dims[i] = shape[i]; + auto src_md = in_mem->get_desc(); + data_ = std::make_shared(src_md, engine, nullptr); + mkldnn_dims_t strides; + mkldnn_dims_t sh; unsigned int total_stride = 1; for (int i = data_ndim - 1; i >= 0; i--) { - dst_fmt.layout_desc.blocking.padding_dims[i] = shape[i]; - dst_fmt.layout_desc.blocking.block_dims[i] = 1; - dst_fmt.layout_desc.blocking.offset_padding_to_data[i]= 0; - // strides[0]: stride between the first elements of adjacent blocks. - dst_fmt.layout_desc.blocking.strides[0][axes[i]] = total_stride; - // strides[1]: strides between elements in the same block. - dst_fmt.layout_desc.blocking.strides[1][axes[i]] = 1; - + sh[i] = shape[i]; + strides[axes[i]] = total_stride; total_stride *= shape[axes[i]]; } - dst_fmt.layout_desc.blocking.offset_padding = 0; - dst_pd_ = std::make_shared(dst_fmt, engine); - out_ = std::make_shared(*dst_pd_, nullptr); + mkldnn_memory_desc_t dst_fmt; + mkldnn_memory_desc_init_by_strides(&dst_fmt, data_ndim, sh, mkldnn_f32, strides); + + dst_md_ = std::make_shared(dst_fmt); + out_ = std::make_shared(*dst_md_, engine, nullptr); transpose_ = std::make_shared(*data_, *out_); } @@ -121,6 +107,14 @@ class MKLDNNTransposeForward { const mkldnn::reorder &GetFwd() const { return *transpose_; } + + void Execute() const { + auto stream = MKLDNNStream::Get(); + mkldnn_args_map_t net_args; + net_args.insert({{MKLDNN_ARG_FROM, *(data_)}, {MKLDNN_ARG_TO, *(out_)}}); + stream->RegisterPrimArgs(*transpose_, net_args); + stream->Submit(); + } }; static MKLDNNTransposeForward &GetTransposeForward(const TransposeParam& param, @@ -150,13 +144,11 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs, const NDArray &output) { const TransposeParam& param = nnvm::get(attrs.parsed); - auto stream = MKLDNNStream::Get(); auto fwd = GetTransposeForward(param, data); - fwd.SetNewMem(data, output); - stream->RegisterPrim(fwd.GetFwd()); - stream->Submit(); + fwd.Execute(); } } // namespace op } // namespace mxnet #endif + diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 5290c09ec00d..c23a5a852dcb 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -528,15 +528,44 @@ class OpSignature { #if MXNET_USE_MKLDNN == 1 void AddSign(const mkldnn::memory &mem) { - auto desc = mem.get_primitive_desc().desc(); - hash = hash * 2 + desc.data.format; - eles.push_back(desc.data.format); + auto desc = mem.get_desc(); + hash = hash * 2 + desc.data.format_kind; + eles.push_back(desc.data.format_kind); hash = hash * 2 + desc.data.data_type; eles.push_back(desc.data.data_type); for (int i = 0; i < desc.data.ndims; i++) { hash = hash * 2 + desc.data.dims[i]; eles.push_back(desc.data.dims[i]); } + switch (desc.data.format_kind) { + case mkldnn_blocked: + hash = hash * 2 + desc.data.ndims; + eles.push_back(desc.data.ndims); + for (int i = 0; i < desc.data.ndims; i++) { + hash = hash * 2 + desc.data.format_desc.blocking.strides[i]; + eles.push_back(desc.data.format_desc.blocking.strides[i]); + } + hash = hash * 2 + desc.data.format_desc.blocking.inner_nblks; + eles.push_back(desc.data.format_desc.blocking.inner_nblks); + for (int i = 0; i < desc.data.format_desc.blocking.inner_nblks; i++) { + hash = hash * 2 + desc.data.format_desc.blocking.inner_blks[i]; + hash = hash * 2 + desc.data.format_desc.blocking.inner_idxs[i]; + eles.push_back(desc.data.format_desc.blocking.inner_blks[i]); + eles.push_back(desc.data.format_desc.blocking.inner_idxs[i]); + } + break; + case mkldnn_format_kind_wino: + hash = hash * 2 + desc.data.format_desc.wino_desc.wino_format; + eles.push_back(desc.data.format_desc.wino_desc.wino_format); + break; + case mkldnn_format_kind_rnn_packed: + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.format; + eles.push_back(desc.data.format_desc.rnn_packed_desc.format); + break; + default: + // nothing need to add + break; + } } #endif diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h index 27fa070afbe0..7ad7aeb7f757 100644 --- a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h @@ -48,8 +48,8 @@ class SgMKLDNNDequantizeOperator { DequantizeParam param_; float cached_data_min_{0.f}; float cached_data_max_{0.f}; - std::shared_ptr i_mem_; - std::shared_ptr o_mem_; + mkldnn::memory::desc o_desc_; + mkldnn_args_map_t args_; std::shared_ptr fwd_pd_; }; @@ -79,37 +79,30 @@ void SgMKLDNNDequantizeOperator::Forward(const OpContext &ctx, const std::vector LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type"; } float scale = real_range / quantized_range; - primitive_attr attr; + mkldnn::primitive_attr attr; const int mask = 0; std::vector scales = {scale}; attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); - auto i_mpd = i_mem->get_primitive_desc(); - auto i_desc = i_mpd.desc(); + auto i_desc = i_mem->get_desc(); size_t i_ndim = in_buffer.shape().ndim(); - mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); - for (size_t i = 0; i < i_ndim; i++) { - i_dims[i] = static_cast(in_buffer.shape()[i]); - } - mkldnn::memory::format o_fmt = static_cast(i_desc.data.format); - if (o_fmt == mkldnn::memory::format::nhwc) { - // For 4d tensor, nchw is the default format - o_fmt = mkldnn::memory::format::nchw; + if (i_ndim == 4) { + mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nchw; + mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims); + o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type(), o_fmt); + } else { + o_desc_ = i_desc; + o_desc_.data.data_type = get_mkldnn_type_t(); } - auto o_desc = - mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum::type, o_fmt); - auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); - auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); - i_mem_ = std::make_shared(i_mpd, nullptr); - o_mem_ = std::make_shared(o_mpd, nullptr); - fwd_pd_ = std::make_shared(reorder_pd, *i_mem_, *o_mem_); + auto reorder_pd = + mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr); + fwd_pd_ = std::make_shared(reorder_pd); initialized_ = true; } - auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]); - i_mem_->set_data_handle(i_mem->get_data_handle()); - o_mem_->set_data_handle(o_mem.second->get_data_handle()); - MKLDNNStream::Get()->RegisterPrim(*fwd_pd_); + auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]); + args_[MKLDNN_ARG_FROM] = *i_mem; + args_[MKLDNN_ARG_TO] = *o_mem.second; + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_); CommitOutput(outputs[0], o_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h index 7a00f621d452..07e2820f5c84 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h @@ -35,11 +35,11 @@ namespace mxnet { namespace op { -template +template static void MKLDNNQuantizeComputeKer(const std::vector& inputs, const std::vector& outputs, const QuantizeParam& param, - const std::vector &req) { + const std::vector& req) { using namespace mshadow; using namespace mxnet_op; using red::limits::MaxValue; @@ -60,38 +60,30 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; } float scale = quantized_range / real_range; - primitive_attr attr; + mkldnn::primitive_attr attr; const int mask = 0; std::vector scales = {scale}; attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); - NDArray in_buffer = inputs[0]; - if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) - in_buffer = inputs[0].Reorder2Default(); + if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default(); auto i_mem = in_buffer.GetMKLDNNData(); - auto i_mpd = i_mem->get_primitive_desc(); - auto i_desc = i_mpd.desc(); - mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); - if (i_fmt == mkldnn::memory::format::nchw || - i_fmt == mkldnn::memory::format::nChw8c || - i_fmt == mkldnn_nChw16c) { - i_fmt = mkldnn::memory::format::nhwc; - } + auto i_desc = i_mem->get_desc(); size_t i_ndim = in_buffer.shape().ndim(); - mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); - for (size_t i = 0; i < i_ndim; i++) { - i_dims[i] = static_cast(in_buffer.shape()[i]); + mkldnn::memory::desc o_desc; + if (i_ndim == 4) { + mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nhwc; + mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims); + o_desc = mkldnn::memory::desc(o_dims, get_mkldnn_type(), o_fmt); + } else { + o_desc = i_desc; + o_desc.data.data_type = get_mkldnn_type_t(); } - auto o_desc = mkldnn::memory::desc(i_dims, - (mkldnn::memory::data_type)data_type_enum::type, - i_fmt); - auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); - auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); - auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second)); + auto reorder_pd = mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc, attr); + auto o_mem = CreateMKLDNNMem(outputs[0], o_desc, req[0]); + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::reorder(reorder_pd), {{MKLDNN_ARG_FROM, *i_mem}, {MKLDNN_ARG_TO, *o_mem.second}}); CommitOutput(outputs[0], o_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index 7cdce8e32bc8..6e10efa99f32 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -47,8 +47,8 @@ class SgMKLDNNQuantizeOperator { QuantizeV2Param param_; float cached_data_min_{0.f}; float cached_data_max_{0.f}; - std::shared_ptr i_mem_; - std::shared_ptr o_mem_; + mkldnn::memory::desc o_desc_; + mkldnn_args_map_t args_; std::shared_ptr fwd_pd_; }; @@ -127,36 +127,30 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector scales = {scale}; attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); - auto i_mpd = i_mem->get_primitive_desc(); - auto i_desc = i_mpd.desc(); - mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); - if (i_fmt == mkldnn::memory::format::nchw || i_fmt == mkldnn::memory::format::nChw8c || - i_fmt == mkldnn_nChw16c) { - i_fmt = mkldnn::memory::format::nhwc; - } + auto i_desc = i_mem->get_desc(); size_t i_ndim = in_buffer.shape().ndim(); - mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); - for (size_t i = 0; i < i_ndim; i++) { - i_dims[i] = static_cast(in_buffer.shape()[i]); + if (i_ndim == 4) { + mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nhwc; + mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims); + o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type(out_type), o_fmt); + } else { + o_desc_ = i_desc; + o_desc_.data.data_type = get_mkldnn_type_t(out_type); } - auto o_desc = mkldnn::memory::desc(i_dims, get_mkldnn_type(out_type), i_fmt); - auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); - auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); - i_mem_ = std::make_shared(i_mpd, nullptr); - o_mem_ = std::make_shared(o_mpd, nullptr); - fwd_pd_ = std::make_shared(reorder_pd, *i_mem_, *o_mem_); + auto reorder_pd = + mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr); + fwd_pd_ = std::make_shared(reorder_pd); initalized_ = true; } - auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]); - i_mem_->set_data_handle(i_mem->get_data_handle()); - o_mem_->set_data_handle(o_mem.second->get_data_handle()); - MKLDNNStream::Get()->RegisterPrim(*fwd_pd_); + auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]); + args_[MKLDNN_ARG_FROM] = *i_mem; + args_[MKLDNN_ARG_TO] = *o_mem.second; + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_); CommitOutput(outputs[0], o_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc index 9dd86bd4dd4c..4723ea41bf2f 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc @@ -44,21 +44,22 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const // reorder if data type = uint8 if (in_data[quantized_batchnorm::kData].dtype() == mshadow::kUint8) { - auto u8_pd = data_mem->get_primitive_desc(); - auto u8_md = u8_pd.desc(); - mkldnn::memory::desc s8_md( - mkldnn::memory::dims(u8_md.data.dims, u8_md.data.dims + u8_md.data.ndims), - mkldnn::memory::data_type::s8, static_cast(u8_md.data.format)); - auto s8_pd = mkldnn::memory::primitive_desc(s8_md, CpuEngine::Get()->get_engine()); - auto data_reorder_mem = TmpMemMgr::Get()->Alloc(s8_pd); + auto u8_md = data_mem->get_desc(); + auto s8_md = u8_md; + s8_md.data.data_type = static_cast(mkldnn::memory::data_type::s8); + auto data_reorder_mem = TmpMemMgr::Get()->Alloc(s8_md); std::vector reorder_scale; reorder_scale = {static_cast(kInt8Range) / kUint8Range}; - primitive_attr reorder_attr; - reorder_attr.set_int_output_round_mode(round_mode::round_nearest); + mkldnn::primitive_attr reorder_attr; reorder_attr.set_output_scales(0, reorder_scale); - const auto reorder_pd = mkldnn::reorder::primitive_desc(u8_pd, s8_pd, reorder_attr); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *data_mem, *data_reorder_mem)); + mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine(); + const auto reorder_pd = + mkldnn::reorder::primitive_desc(cpu_engine, u8_md, cpu_engine, s8_md, reorder_attr); + mkldnn_args_map_t reorder_args; + reorder_args[MKLDNN_ARG_SRC] = *data_mem; + reorder_args[MKLDNN_ARG_DST] = *data_reorder_mem; + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), reorder_args); data_mem = data_reorder_mem; } const size_t channelAxis = static_cast( @@ -79,10 +80,11 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const } const float max_abs_output = std::max(std::abs(*min_output_ptr), std::abs(*max_output_ptr)); - unsigned flags = mkldnn::use_global_stats | mkldnn::use_scale_shift; + mkldnn::normalization_flags flags = + mkldnn::normalization_flags::use_global_stats | mkldnn::normalization_flags::use_scale_shift; auto &fwd = GetBNForward(param, ctx, data_mem, flags); const mkldnn::memory &weight_mem = fwd.GetWeight(); - CHECK_EQ(weight_mem.get_primitive_desc().get_size(), channel_count * sizeof(float) * 2); + CHECK_EQ(weight_mem.get_desc().get_size(), channel_count * sizeof(float) * 2); float *weight_buf = reinterpret_cast(weight_mem.get_data_handle()); float *gamma_ptr = in_data[quantized_batchnorm::kGamma].data().dptr(); @@ -94,9 +96,8 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const float *moving_var_ptr = moving_var.data().dptr(); // rescale gamma and beta, to make mean=0 and var=1 - auto rescaled_mean_mem = - TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_primitive_desc()); - auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_primitive_desc()); + auto rescaled_mean_mem = TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_desc()); + auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_desc()); float *rescaled_mean_ptr = reinterpret_cast(rescaled_mean_mem->get_data_handle()); float *rescaled_var_ptr = reinterpret_cast(rescaled_var_mem->get_data_handle()); @@ -111,11 +112,16 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const rescaled_var_ptr[channel] = 1.0f; } - auto out_mem = CreateMKLDNNMem(outputs[batchnorm::kOut], - fwd.GetPd().dst_primitive_desc(), req[batchnorm::kOut], &data); - fwd.SetDataHandle(data_mem, rescaled_mean_mem, rescaled_var_mem, out_mem.second); + const NDArray &out = outputs[batchnorm::kOut]; + auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_desc()); + mkldnn_args_map_t net_args; + net_args[MKLDNN_ARG_SRC] = *data_mem; + net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem; + net_args[MKLDNN_ARG_DST] = *out_mem; + net_args[MKLDNN_ARG_MEAN] = *rescaled_mean_mem; + net_args[MKLDNN_ARG_VARIANCE] = *rescaled_var_mem; - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc index 2a4c6d612e65..619e8bf1e1fb 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc @@ -60,7 +60,7 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC out_data[quantized_concat_enum::kMin].data().dptr()[0] = output_neg_min; out_data[quantized_concat_enum::kMax].data().dptr()[0] = output_pos_max; auto out_scale = GetScale(out_data[quantized_concat_enum::kOut], output_neg_min, output_pos_max); - std::vector data_md; + std::vector data_md; std::vector data_mem; // new_data_mem is for auto-free new created mkldnn memory std::vector> new_data_mem; @@ -71,36 +71,37 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC CHECK(in_data[i].dtype() == out_dtype); auto mem = in_data[i].GetMKLDNNData(); data_mem.push_back(mem); - data_md.push_back(mem->get_primitive_desc()); + data_md.push_back(mem->get_desc()); } else { auto mem = in_data[i].GetMKLDNNData(); - auto pd = mem->get_primitive_desc(); + auto mem_desc = mem->get_desc(); if (in_data[i].dtype() != out_dtype) { - auto mem_desc = pd.desc(); - mkldnn::memory::desc new_md( - mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims), - get_mkldnn_type(out_dtype), static_cast(mem_desc.data.format)); - pd = mkldnn::memory::primitive_desc(new_md, CpuEngine::Get()->get_engine()); + mem_desc.data.data_type = static_cast(get_mkldnn_type(out_dtype)); } - const auto rescaled_mem = std::make_shared(pd); + const auto rescaled_mem = + std::make_shared(mem_desc, CpuEngine::Get()->get_engine()); new_data_mem.push_back(rescaled_mem); std::vector reorder_scale = {out_scale / i_scale}; - primitive_attr reorder_attr; - reorder_attr.set_int_output_round_mode(round_mode::round_nearest); + mkldnn::primitive_attr reorder_attr; reorder_attr.set_output_scales(0, reorder_scale); - const auto reorder_pd = - mkldnn::reorder::primitive_desc(mem->get_primitive_desc(), pd, reorder_attr); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, *rescaled_mem)); + const auto reorder_pd = mkldnn::reorder::primitive_desc(*mem, *rescaled_mem, reorder_attr); + mkldnn_args_map_t reorder_args; + reorder_args[MKLDNN_ARG_SRC] = *mem; + reorder_args[MKLDNN_ARG_DST] = *rescaled_mem; + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), reorder_args); data_mem.push_back(rescaled_mem.get()); - data_md.push_back(pd); + data_md.push_back(mem_desc); } } MKLDNNConcatFwd& fwd = GetConcatForward(param_.dim, in_data, data_md); - mxnet::mkldnn_output_t out_mem = - CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_primitive_desc(), - req[concat_enum::kOut]); - fwd.SetNewMem(data_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], + fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]); + mkldnn_args_map_t net_args; + net_args[MKLDNN_ARG_DST] = *out_mem.second; + for (int i = 0; i < param_.num_args; i++) { + net_args[MKLDNN_ARG_MULTIPLE_SRC + i] = *data_mem[i]; + } + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); CommitOutput(out_data[concat_enum::kOut], out_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc index f81071704762..6ac2250281d3 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc @@ -43,11 +43,12 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs, TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); NDArray weight = in_data[conv::kWeight]; ConvolutionParam param = nnvm::get(attrs.parsed); - auto &fwd = GetConvFwd( - param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], - param.no_bias ? nullptr : &in_data[conv::kBias], - out_data[conv::kOut]); - auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); + MKLDNNConvFullParam full_param; + full_param.conv_param = param; + full_param.mkldnn_param.Init(std::unordered_map()); + auto &fwd = GetConvFwd(full_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]); + auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.GetPd().src_desc()); const mkldnn::memory *weight_mem; // For inference, we want to reorder the weight array so we don't need to // reorder data every time. @@ -55,20 +56,23 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs, // We also need to modify the layout on the original weight array. // Don't switch below sequence because naive engine will executes // pushAsync synchronously. - weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); + weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc()); + weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group); } else { weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); } - auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(), + auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.GetPd().dst_desc(), req[conv::kOut]); - const mkldnn::memory *bias_mem = nullptr; - if (!param.no_bias) - bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); - fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); - + mkldnn_args_map_t net_args; + if (!param.no_bias) { + const mkldnn::memory *bias_mem = + in_data[conv::kBias].GetMKLDNNDataReorder(fwd.GetPd().bias_desc()); + net_args.insert({MKLDNN_ARG_BIAS, *bias_mem}); + } + net_args.insert({MKLDNN_ARG_SRC, *data_mem}); + net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem}); + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); CommitOutput(out_data[conv::kOut], out_mem); MKLDNNStream::Get()->Submit(); Stream *s = ctx.get_stream(); diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc index 2be6b2baca63..2078ac4fead8 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc @@ -73,17 +73,17 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons // output default set as int32 float output_data_range = kInt32Range; - auto output_data_type = mkldnn::memory::s32; + auto output_data_type = mkldnn::memory::data_type::s32; // dataA && dataB are uint8 if (out_data[quantized_elemwise_add_enum::kOut].dtype() == mshadow::kInt8) { output_data_range = kInt8Range; - output_data_type = mkldnn::memory::s8; + output_data_type = mkldnn::memory::data_type::s8; } else if (out_data[quantized_elemwise_add_enum::kOut].dtype() == mshadow::kUint8) { output_data_range = kUint8Range; - output_data_type = mkldnn::memory::u8; + output_data_type = mkldnn::memory::data_type::u8; } else { output_data_range = kInt32Range; - output_data_type = mkldnn::memory::s32; + output_data_type = mkldnn::memory::data_type::s32; } float output_min = 0; @@ -100,12 +100,13 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons // 2: scale 0 for dataA, scale 1 for data B const int scales_num = 2; std::vector scales(scales_num, 1); + auto engine = CpuEngine::Get()->get_engine(); if (in_data[quantized_elemwise_add_enum::kDataA].dtype() != in_data[quantized_elemwise_add_enum::kDataB].dtype()) { - auto s8_pd = (is_dataA_int8 == true) - ? dataA_mem->get_primitive_desc() - : dataB_mem->get_primitive_desc(); - rescaled_mem = TmpMemMgr::Get()->Alloc(s8_pd); + auto s8_desc = (is_dataA_int8 == true) + ? dataA_mem->get_desc() + : dataB_mem->get_desc(); + rescaled_mem = TmpMemMgr::Get()->Alloc(s8_desc); float u8_reorder_scale = 0; if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) { if (is_dataA_int8 == true) { @@ -130,14 +131,16 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons } } std::vector reorder_scale = {u8_reorder_scale}; - primitive_attr reorder_attr; - reorder_attr.set_int_output_round_mode(round_mode::round_nearest); + mkldnn::primitive_attr reorder_attr; reorder_attr.set_output_scales(0, reorder_scale); auto u8_mem = (is_dataA_int8 == true) ? dataB_mem : dataA_mem; - const auto reorder_pd = mkldnn::reorder::primitive_desc(u8_mem->get_primitive_desc(), - s8_pd, + const auto reorder_pd = mkldnn::reorder::primitive_desc(engine, + u8_mem->get_desc(), + engine, + s8_desc, reorder_attr); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *u8_mem, *rescaled_mem)); + mkldnn_args_map_t args({{MKLDNN_ARG_FROM, *u8_mem }, {MKLDNN_ARG_TO, *rescaled_mem}}); + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), args); if (is_dataA_int8 == true) { dataB_mem = rescaled_mem; @@ -155,27 +158,24 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons } } - std::vector in_prims; - std::vector in_pds; - in_prims.push_back(*dataA_mem); - in_prims.push_back(*dataB_mem); - in_pds.push_back(dataA_mem->get_primitive_desc()); - in_pds.push_back(dataB_mem->get_primitive_desc()); - size_t i_ndim = in_data[quantized_elemwise_add_enum::kDataA].shape().ndim(); - mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); - for (size_t i = 0; i < i_ndim; i++) { - i_dims[i] = static_cast(in_data[quantized_elemwise_add_enum::kDataA].shape()[i]); - } - mkldnn::memory::format i_fmt = static_cast( - in_pds[quantized_elemwise_add_enum::kDataA].desc().data.format); - auto output_desc = mkldnn::memory::desc(i_dims, output_data_type, i_fmt); - mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_pds); + std::vector in_desc; + in_desc.push_back(dataA_mem->get_desc()); + in_desc.push_back(dataB_mem->get_desc()); + const auto in_shape = in_data[quantized_elemwise_add_enum::kDataA].shape(); + mkldnn::memory::dims i_dims(in_shape.begin(), in_shape.end()); + auto output_desc = mkldnn::memory::desc(i_dims, + output_data_type, + mkldnn::memory::format_tag::any); + mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_desc, engine); auto mem = CreateMKLDNNMem(out_data[quantized_elemwise_add_enum::kOut], - pdesc.dst_primitive_desc(), + pdesc.dst_desc(), req[0], &in_data[0]); + mkldnn_args_map_t args({{MKLDNN_ARG_MULTIPLE_SRC, *dataA_mem}, + {MKLDNN_ARG_MULTIPLE_SRC + 1, *dataB_mem}, + {MKLDNN_ARG_DST, *mem.second}}); MKLDNNStream *stream = MKLDNNStream::Get(); - stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second)); + stream->RegisterPrimArgs(mkldnn::sum(pdesc), args); CommitOutput(out_data[quantized_elemwise_add_enum::kOut], mem); stream->Submit(); diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc index f451ff3c977d..3e21564b3b04 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc @@ -94,28 +94,35 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs, auto &fwd = GetFCFwd(param, is_train, data, weight, param.no_bias ? nullptr : &quantized_bias, out_md); - auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); + auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_desc()); const mkldnn::memory *weight_mem = nullptr; if (weight.IsDefaultData()) { // We also need to modify the layout on the original weight array. // Don't switch below sequence because naive engine will executes // pushAsync synchronously. - weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1); + weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_desc()); + weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1); } else { weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); + CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc()); } - auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_primitive_desc(), + auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_desc(), req[fullc::kOut]); - const mkldnn::memory *bias_mem = nullptr; - if (!param.no_bias) - bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); - fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + mkldnn_args_map_t args = { + {MKLDNN_ARG_SRC, *data_mem}, + {MKLDNN_ARG_WEIGHTS, *weight_mem}, + {MKLDNN_ARG_DST, *out_mem.second}, + }; + + const mkldnn::memory *bias_mem = nullptr; + if (!param.no_bias) { + bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_desc()); + args[MKLDNN_ARG_BIAS] = *bias_mem; + } + MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), args); CommitOutput(out_data[fullc::kOut], out_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc index 07e14412618d..190dfed23197 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc @@ -38,9 +38,7 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op || in_data[0].dtype() == mshadow::kInt8) << "mkldnn_quantized_pooling op only supports uint8 and int8 as input type"; const PoolingParam& param = nnvm::get(attrs.parsed); - auto fwd = GetPoolingFwd(param, ctx.is_train, in_data[0], out_data[0]); - fwd.SetNewMem(in_data[0], out_data[0], req[0]); - fwd.Execute(out_data[0]); + MKLDNNPoolingCompute(ctx, param, in_data[0], req[0], out_data[0], nullptr); out_data[1].data().dptr()[0] = in_data[1].data().dptr()[0]; out_data[2].data().dptr()[0] = in_data[2].data().dptr()[0]; } diff --git a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h index 03d9b9067b57..a80b855dceb1 100644 --- a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h @@ -71,11 +71,10 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs, float second_scale = second_quantized_range / second_real_range; float scale = first_scale * second_scale; - primitive_attr attr; + mkldnn::primitive_attr attr; const int mask = 0; std::vector scales = {scale}; attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); NDArray in_buffer = inputs[0]; @@ -83,20 +82,13 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs, in_buffer = inputs[0].Reorder2Default(); auto i_mem = in_buffer.GetMKLDNNData(); - auto i_mpd = i_mem->get_primitive_desc(); - auto i_desc = i_mpd.desc(); - mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); - mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_dim); - for (size_t i = 0; i < i_dim; i++) { - i_dims[i] = static_cast(in_buffer.shape()[i]); - } - auto o_desc = mkldnn::memory::desc(i_dims, - (mkldnn::memory::data_type)data_type_enum::type, - i_fmt); - auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); - auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); - auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second)); + auto i_desc = i_mem->get_desc(); + auto o_desc = i_desc; + o_desc.data.data_type = get_mkldnn_type_t(); + auto reorder_pd = mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc, attr); + auto o_mem = CreateMKLDNNMem(outputs[0], o_desc, req[0]); + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::reorder(reorder_pd), {{MKLDNN_ARG_FROM, *i_mem}, {MKLDNN_ARG_TO, *o_mem.second}}); CommitOutput(outputs[0], o_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc index 1839e2a29d77..eeb2ac4de26c 100644 --- a/src/operator/quantization/quantized_pooling.cc +++ b/src/operator/quantization/quantized_pooling.cc @@ -98,7 +98,7 @@ bool QuantizedPoolingType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_type->size(), 3U); CHECK_EQ(out_type->size(), 3U); if (param.pool_type == pool_enum::kMaxPooling || param.pool_type == pool_enum::kAvgPooling) { -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 TYPE_ASSIGN_CHECK(*out_type, 0, (*in_type)[0]); #else TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8); diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index d5fd351986e3..db2360313aef 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -26,11 +26,6 @@ #ifndef MXNET_OPERATOR_RNN_INL_H_ #define MXNET_OPERATOR_RNN_INL_H_ -#if MXNET_USE_CUDNN == 1 -STATIC_ASSERT_CUDNN_VERSION_GE(7000); -#endif -#define MXNET_USE_CUDNN_GE_7200 MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 - #include #include #include @@ -46,13 +41,87 @@ STATIC_ASSERT_CUDNN_VERSION_GE(7000); #include "./math_functions-inl.h" #include "./operator_common.h" #include "./rnn_impl.h" -#if MXNET_USE_MKLDNN == 1 -#include "./nn/mkldnn/mkldnn_rnn_impl.h" + +#if MXNET_USE_CUDNN == 1 +STATIC_ASSERT_CUDNN_VERSION_GE(7000); #endif +#define MXNET_USE_CUDNN_GE_7200 MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 namespace mxnet { namespace op { +namespace rnn_enum { + enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength}; + enum RNNOpOutputs {kOut, kStateOut, kStateCellOut}; + enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru}; + enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace}; +} + +struct RNNParam : public dmlc::Parameter { + uint32_t state_size; + uint32_t num_layers; + bool bidirectional, state_outputs; + int mode; + float p; + int seq_length_, batch_size_, input_size_; + + bool use_sequence_length; + dmlc::optional projection_size; + dmlc::optional lstm_state_clip_min, lstm_state_clip_max; + bool lstm_state_clip_nan; + + DMLC_DECLARE_PARAMETER(RNNParam) { + DMLC_DECLARE_FIELD(state_size) + .describe("size of the state for each layer"); + + DMLC_DECLARE_FIELD(num_layers) + .describe("number of stacked layers"); + + DMLC_DECLARE_FIELD(bidirectional).set_default(false) + .describe("whether to use bidirectional recurrent layers"); + + DMLC_DECLARE_FIELD(mode) + .add_enum("rnn_relu", rnn_enum::kRnnRelu) + .add_enum("rnn_tanh", rnn_enum::kRnnTanh) + .add_enum("lstm", rnn_enum::kLstm) + .add_enum("gru", rnn_enum::kGru) + .describe("the type of RNN to compute"); + + DMLC_DECLARE_FIELD(p).set_default(0.) + .set_range(0, 1) + .describe("drop rate of the dropout on the outputs of each RNN layer, except the last layer."); + + DMLC_DECLARE_FIELD(state_outputs).set_default(false) + .describe("Whether to have the states as symbol outputs."); + + DMLC_DECLARE_FIELD(projection_size) + .set_default(dmlc::optional()) + .describe("size of project size"); + + DMLC_DECLARE_FIELD(lstm_state_clip_min) + .set_default(dmlc::optional()) + .describe("Minimum clip value of LSTM states. This option must be used together with " + "lstm_state_clip_max."); + + DMLC_DECLARE_FIELD(lstm_state_clip_max) + .set_default(dmlc::optional()) + .describe("Maximum clip value of LSTM states. This option must be used together with " + "lstm_state_clip_min."); + + DMLC_DECLARE_FIELD(lstm_state_clip_nan) + .set_default(false) + .describe("Whether to stop NaN from propagating in state by clipping it to min/max. " + "If clipping range is not specified, this option is ignored."); + + DMLC_DECLARE_FIELD(use_sequence_length) + .set_default(false) + .describe( + "If set to true, this layer takes in an extra input parameter " + "`sequence_length` " + "to specify variable length sequence"); + } +}; + inline int GetRnnParamSize(int num_layer, int input_size, int state_size, @@ -86,9 +155,9 @@ inline int GetRnnParamSize(int num_layer, } inline int GetRnnBiasSize(int num_layer, - int state_size, - int direction, - int mode) { + int state_size, + int direction, + int mode) { int size = 2 * state_size * direction * num_layer; switch (mode) { case rnn_enum::kRnnRelu: @@ -104,6 +173,15 @@ inline int GetRnnBiasSize(int num_layer, return size; } +/* + * Calculate the space size of the intermediate results for RNN inference. + * The inference procedure of a fusion RNN operator calculates the outputs + * layer by layer. In one layer calculation, the steps are: + * - wx[1...Ngates] * x[1...T] among all time stamp(sz: TxNxHxNgates) + * - wh[1...Ngates] * h[t] time by time(sz: NxHxNgates) + * - output -> h[t](, c[t] additionally with Lstm) time by time(sz: NxH(x2)) + * - intermediate y[1...T] as next layer's inputs(sz: TxNxHxD) + */ inline size_t GetRNNWorkspaceSize(int seq_length, int batch_size, int hidden_size, @@ -112,15 +190,19 @@ inline size_t GetRNNWorkspaceSize(int seq_length, size_t size = 0; switch (mode) { case rnn_enum::kLstm: - size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2 - + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8; + size = seq_length * batch_size * hidden_size * (4 + direction) + // wx*x + inter-y + batch_size * hidden_size * 6 + // wh*h + h + c + seq_length * hidden_size * 8; // Used in Backward, Δbx, Δbh break; case rnn_enum::kGru: - size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8; + // Differs with Lstm, the outputs of three gates are also held in memory + size = seq_length * batch_size * hidden_size * direction * (3 + 1) + // wx*x + inter-y + batch_size * hidden_size * (6 + direction); // wh*h + h + Ngates break; case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: - size = seq_length * batch_size * hidden_size * direction * 2 + batch_size * hidden_size * 4; + size = seq_length * batch_size * hidden_size * direction * 2 + // wx*x + inter-y + batch_size * hidden_size * (1 + direction); // h + Ngates break; default: LOG(FATAL) << "unknown RNN mode " << mode; @@ -158,71 +240,6 @@ inline size_t GetRNNReserveSpaceSize(int num_layer, return size; } -struct RNNParam : public dmlc::Parameter { - uint32_t state_size; - uint32_t num_layers; - bool bidirectional, state_outputs; - int mode; - float p; - int seq_length_, batch_size_, input_size_; - - bool use_sequence_length; - dmlc::optional projection_size; - dmlc::optional lstm_state_clip_min, lstm_state_clip_max; - bool lstm_state_clip_nan; - - DMLC_DECLARE_PARAMETER(RNNParam) { - DMLC_DECLARE_FIELD(state_size) - .describe("size of the state for each layer"); - - DMLC_DECLARE_FIELD(num_layers) - .describe("number of stacked layers"); - - DMLC_DECLARE_FIELD(bidirectional).set_default(false) - .describe("whether to use bidirectional recurrent layers"); - - DMLC_DECLARE_FIELD(mode) - .add_enum("rnn_relu", rnn_enum::kRnnRelu) - .add_enum("rnn_tanh", rnn_enum::kRnnTanh) - .add_enum("lstm", rnn_enum::kLstm) - .add_enum("gru", rnn_enum::kGru) - .describe("the type of RNN to compute"); - - DMLC_DECLARE_FIELD(p).set_default(0.) - .set_range(0, 1) - .describe("drop rate of the dropout on the outputs of each RNN layer, except the last layer."); - - DMLC_DECLARE_FIELD(state_outputs).set_default(false) - .describe("Whether to have the states as symbol outputs."); - - DMLC_DECLARE_FIELD(projection_size) - .set_default(dmlc::optional()) - .describe("size of project size"); - - DMLC_DECLARE_FIELD(lstm_state_clip_min) - .set_default(dmlc::optional()) - .describe("Minimum clip value of LSTM states. This option must be used together with " - "lstm_state_clip_max."); - - DMLC_DECLARE_FIELD(lstm_state_clip_max) - .set_default(dmlc::optional()) - .describe("Maximum clip value of LSTM states. This option must be used together with " - "lstm_state_clip_min."); - - DMLC_DECLARE_FIELD(lstm_state_clip_nan) - .set_default(false) - .describe("Whether to stop NaN from propagating in state by clipping it to min/max. " - "If clipping range is not specified, this option is ignored."); - - DMLC_DECLARE_FIELD(use_sequence_length) - .set_default(false) - .describe( - "If set to true, this layer takes in an extra input parameter " - "`sequence_length` " - "to specify variable length sequence"); - } -}; - inline size_t GetNumInputArguments(RNNParam param_) { size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U; if (param_.use_sequence_length) num_inputs += 1U; @@ -398,114 +415,94 @@ class RNNOp { public: RNNParam param_; Context ctx_; -#if MXNET_USE_MKLDNN == 1 - std::vector concat_weight_memory; - std::vector concat_iter_memory; - std::vector rnn_forward_prim; - std::vector x_memory; - std::vector hcx_memory; - std::vector wx_memory; - std::vector wh_memory; - std::vector bias_memory; - std::vector y_memory; - std::vector hcy_memory; - size_t weights_version; - bool has_cache; - bool init_mem_; - size_t reserve_mem_size_; - NDArray mem_space_; -#endif + explicit RNNOp(RNNParam param, Context ctx) { this->param_ = param; this->ctx_ = ctx; -#if MXNET_USE_MKLDNN == 1 - init_mem_ = false; - reserve_mem_size_ = 0; -#endif if (ctx_.dev_type == kGPU) { #if MXNET_USE_CUDNN == 1 - init_cudnn_ = false; - dtype_ = mshadow::DataType::kCudnnFlag; - // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy. - // No tests in place for fp16 RNNs, so leave TensorCore disabled for now. - cudnn_tensor_core_ = false; - // When fp16 RNN tests are introduced, we can enable TensorCore as follows: - // cudnn_tensor_core = - // mshadow::DataType::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore(); - // Defaults - input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet - // RNN Mode - switch (param_.mode) { - case rnn_enum::kRnnRelu: - mode_ = CUDNN_RNN_RELU; - break; - case rnn_enum::kRnnTanh: - mode_ = CUDNN_RNN_TANH; - break; - case rnn_enum::kLstm: - mode_ = CUDNN_LSTM; - break; - case rnn_enum::kGru: - mode_ = CUDNN_GRU; - break; - default: - LOG(FATAL) << "Not implmented"; - } + init_cudnn_ = false; + dtype_ = mshadow::DataType::kCudnnFlag; + // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy. + // No tests in place for fp16 RNNs, so leave TensorCore disabled for now. + cudnn_tensor_core_ = false; + // When fp16 RNN tests are introduced, we can enable TensorCore as follows: + // cudnn_tensor_core = + // mshadow::DataType::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore(); + // Defaults + input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet + // RNN Mode + switch (param_.mode) { + case rnn_enum::kRnnRelu: + mode_ = CUDNN_RNN_RELU; + break; + case rnn_enum::kRnnTanh: + mode_ = CUDNN_RNN_TANH; + break; + case rnn_enum::kLstm: + mode_ = CUDNN_LSTM; + break; + case rnn_enum::kGru: + mode_ = CUDNN_GRU; + break; + default: + LOG(FATAL) << "Not implmented"; + } #if MXNET_USE_CUDNN_GE_7200 - if (param_.projection_size.has_value()) { - CHECK_EQ(param_.mode, rnn_enum::kLstm) - << "Projection is only supported for LSTM."; - CHECK_GE(param_.state_size, param_.projection_size.value()) - << "State size must be larger than projection size."; - } + if (param_.projection_size.has_value()) { + CHECK_EQ(param_.mode, rnn_enum::kLstm) + << "Projection is only supported for LSTM."; + CHECK_GE(param_.state_size, param_.projection_size.value()) + << "State size must be larger than projection size."; + } #else - CHECK(!param_.projection_size.has_value()) - << "Projection is only supported for LSTM with CuDNN version later than 7.1.1."; + CHECK(!param_.projection_size.has_value()) + << "Projection is only supported for LSTM with CuDNN version later than 7.1.1."; #endif // MXNET_USE_CUDNN_GE_7200 #if MXNET_USE_CUDNN_GE_7200 - if (param_.lstm_state_clip_min.has_value() - || param_.lstm_state_clip_max.has_value()) { - CHECK_EQ(param_.mode, rnn_enum::kLstm) - << "State clipping is only supported for LSTM."; - CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value()) - << "lstm_state_clip_min and lstm_state_clip_max must be specified together."; - CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value()) - << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min"; - } + if (param_.lstm_state_clip_min.has_value() + || param_.lstm_state_clip_max.has_value()) { + CHECK_EQ(param_.mode, rnn_enum::kLstm) + << "State clipping is only supported for LSTM."; + CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value()) + << "lstm_state_clip_min and lstm_state_clip_max must be specified together."; + CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value()) + << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min"; + } #else - CHECK(!param_.lstm_state_clip_min.has_value() - && !param_.lstm_state_clip_max.has_value()) - << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1."; + CHECK(!param_.lstm_state_clip_min.has_value() + && !param_.lstm_state_clip_max.has_value()) + << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1."; #endif // MXNET_USE_CUDNN_GE_7200 - // RNN Direction - direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - // Create descriptors - CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_)); - - CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_)); - CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_)); - - CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_)); - CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_)); + // RNN Direction + direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + // Create descriptors + CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_)); + + CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_)); + CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_)); + + CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_)); + CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_)); #if MXNET_USE_CUDNN_GE_7200 - CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_)); - CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_)); - CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_)); - CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_)); -#endif + CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_)); + CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_)); + CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_)); + CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_)); +#endif // MXNET_USE_CUDNN_GE_7200 #else - if (ctx_.dev_type == kGPU) { - LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment."; - } + if (ctx_.dev_type == kGPU) { + LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment."; + } #endif // MXNET_USE_CUDNN == 1 } @@ -857,7 +854,7 @@ class RNNOp { } DType* work_cpu_space = static_cast(temp_cpu_space_.data().dptr_); - if (ctx.is_train) { + if (ctx.is_train || ctx.need_grad) { // allocate reserve space const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, @@ -891,64 +888,23 @@ class RNNOp { param_.p, param_.mode); } else { -#if MXNET_USE_MKLDNN == 1 - if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) { - // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one - // will be added to MXNet when we figure out the issue. - int dtype = in_data[rnn_enum::kData].type_flag_; - MKLDNNRNNForwardInference(param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - &concat_weight_memory, - &concat_iter_memory, - &x_memory, - &hcx_memory, - &wx_memory, - &wh_memory, - &bias_memory, - &y_memory, - &hcy_memory, - &rnn_forward_prim, - &has_cache, - dtype, - ctx.is_train, - param_.mode); - } else { -#endif // MXNET_USE_MKLDNN == 1 - // Before integrating MKLDNN GRU fp32 inference - // using below code for keep func being OK - RNNForwardInference(work_cpu_space, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.mode); -#if MXNET_USE_MKLDNN == 1 - } -#endif + RNNForwardInference(work_cpu_space, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); } } } @@ -1489,6 +1445,10 @@ class RNNOp { } #endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) } + // naive private variables used in CPU Context + bool init_space_, temp_init_space_; + size_t reserve_cpu_space_size_, temp_cpu_space_size_; + NDArray reserve_cpu_space_, temp_cpu_space_; #if MXNET_USE_CUDNN == 1 && defined(__CUDACC__) // cuDNN versions up to and including v7.6.4 did not sync a last dgrad kernel back to the main @@ -1538,39 +1498,8 @@ class RNNOp { bool dgrad_sync_event_created_ = false; bool dgrad_sync_needed_ = false; #endif // MXNET_USE_CUDNN - bool init_space_, temp_init_space_; - size_t reserve_cpu_space_size_, temp_cpu_space_size_; - NDArray reserve_cpu_space_, temp_cpu_space_; }; // class RNNOp -static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, - const Context ctx, - const mxnet::ShapeVector &in_shapes, - const std::vector &in_types) { - const RNNParam& param = nnvm::get(attrs.parsed); - OpStatePtr state = OpStatePtr(); - int dtype = in_types[rnn_enum::kData]; - int itype = dtype; - if (param.use_sequence_length) { - size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param.mode != rnn_enum::kLstm) { - seq_len_input_idx -= 1; - } - itype = in_types[seq_len_input_idx]; - } - - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - MSHADOW_TYPE_SWITCH(itype, IType, { - if (ctx.dev_type == kGPU) { - state = OpStatePtr::Create>(param, ctx); - } else { - state = OpStatePtr::Create>(param, ctx); - } - }); - }); - return state; -} - template void RNNStatefulCompute(const OpStatePtr& state, const OpContext& ctx, @@ -1582,14 +1511,14 @@ void RNNStatefulCompute(const OpStatePtr& state, // Hacky. This relies on fact that seq-len type is either the last input, // or we aren't using seq-len input and this type should be same as dtype. // Would prefer direct access to RNNParam object here but not sure how to get. - int itype = inputs[inputs.size()-1].type_flag_; + int itype = inputs[inputs.size() - 1].type_flag_; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - MSHADOW_TYPE_SWITCH(itype, IType, { - RNNOp& op = state.get_state>(); - op.Forward(ctx, inputs, req, outputs); - }); + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state.get_state>(); + op.Forward(ctx, inputs, req, outputs); }); + }); } /* @@ -1621,38 +1550,38 @@ void RNNStatefulGradCompute(const OpStatePtr& state, // Hacky. This relies on fact that seq-len type is either the last input, // or we aren't using seq-len input and this type should be same as dtype. // Would prefer direct access to RNNParam object here but not sure how to get. - int itype = outputs[outputs.size()-1].type_flag_; + int itype = outputs[outputs.size() - 1].type_flag_; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - MSHADOW_TYPE_SWITCH(itype, IType, { - RNNOp& op = state.get_state>(); - const RNNParam& param = op.param_; - int index = 5; - if (param.state_outputs) { - out_data.push_back(inputs[index++]); - out_grad.push_back(inputs[index++]); - } - - if (param.mode == rnn_enum::kLstm) { - in_data.push_back(inputs[index++]); - if (param.state_outputs) { - out_data.push_back(inputs[index++]); - out_grad.push_back(inputs[index]); - } - } - - - if (param.use_sequence_length) { - size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param.mode != rnn_enum::kLstm) { - seq_len_input_idx -= 1; - } - in_data.push_back(outputs[seq_len_input_idx]); - } - - op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); - }); + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state.get_state>(); + const RNNParam& param = op.param_; + int index = 5; + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index++]); + } + + if (param.mode == rnn_enum::kLstm) { + in_data.push_back(inputs[index++]); + if (param.state_outputs) { + out_data.push_back(inputs[index++]); + out_grad.push_back(inputs[index]); + } + } + + + if (param.use_sequence_length) { + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param.mode != rnn_enum::kLstm) { + seq_len_input_idx -= 1; + } + in_data.push_back(outputs[seq_len_input_idx]); + } + + op.Backward(ctx, out_grad, in_data, out_data, req, in_grad); }); + }); } } // namespace op diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index b2ac2f0cb615..6d568c81bc1c 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -27,6 +27,9 @@ #include #include "./rnn-inl.h" +#if MXNET_USE_MKLDNN == 1 +#include "./nn/mkldnn/mkldnn_rnn-inl.h" +#endif // MXNET_USE_MKLDNN == 1 namespace mxnet { namespace op { @@ -190,9 +193,9 @@ inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { DispatchMode wanted_mode = DispatchMode::kFCompute; - #if MXNET_USE_MKLDNN == 1 - wanted_mode = DispatchMode::kFComputeEx; - #endif +#if MXNET_USE_MKLDNN == 1 + wanted_mode = DispatchMode::kFComputeEx; +#endif // MXNET_USE_MKLDNN == 1 return storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode); @@ -222,432 +225,73 @@ struct RNNGrad { } }; -#if MXNET_USE_MKLDNN == 1 -static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - std::vector in_blobs; - std::vector out_blobs; - std::vector temp_ndarrays_i; - std::vector temp_ndarrays_o; - for (const NDArray& in : inputs) { - if (in.storage_type() == kDefaultStorage) { - temp_ndarrays_i.push_back(in.Reorder2Default()); - in_blobs.emplace_back(temp_ndarrays_i.back().data()); - } else { - in_blobs.emplace_back(in.data()); +static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, + const Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + const RNNParam& param = nnvm::get(attrs.parsed); + OpStatePtr state = OpStatePtr(); + int dtype = in_types[rnn_enum::kData]; + int itype = dtype; + if (param.use_sequence_length) { + size_t seq_len_input_idx = rnn_enum::kSequenceLength; + if (param.mode != rnn_enum::kLstm) { + seq_len_input_idx -= 1; } + itype = in_types[seq_len_input_idx]; } - for (const NDArray& out : outputs) { - if (out.storage_type() == kDefaultStorage) { - temp_ndarrays_o.push_back(out.Reorder2Default()); - out_blobs.emplace_back(temp_ndarrays_o.back().data()); - } else { - out_blobs.emplace_back(out.data()); - } +#if MXNET_USE_MKLDNN == 1 + if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16) + && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU) { + const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData]; + state = OpStatePtr::Create(param, data_shape[0], + data_shape[1], data_shape[2]); + return state; } - int dtype = in_blobs[rnn_enum::kData].type_flag_; - int itype = in_blobs[inputs.size()-1].type_flag_; - mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - Stream *s = ctx.get_stream(); - auto cpu_engine = CpuEngine::Get()->get_engine(); +#endif // MXNET_USE_MKLDNN == 1 + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { MSHADOW_TYPE_SWITCH(itype, IType, { - RNNOp& op = state_ptr.get_state>(); - const RNNParam& param = op.param_; - int ngates = 0, nstates = 0; - GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); - int D = param.bidirectional ? 2 : 1; - Tensor x = in_blobs[rnn_enum::kData].get(s); - int T = x.shape_[0]; - int N = x.shape_[1]; - int I = x.shape_[2]; - int H = param.state_size; - int L = param.num_layers; - - const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); - if (op.init_mem_ && op.reserve_mem_size_ < r_size) { - op.init_mem_ = false; - } - const size_t weights_version = inputs[rnn_enum::kParams].version(); - if (!op.init_mem_) { - op.mem_space_ = NDArray(TShape({static_cast(r_size)}), op.ctx_, false, dtype); - op.reserve_mem_size_ = r_size; - op.init_mem_ = true; - op.has_cache = false; - // Assign weights_version - op.weights_version = weights_version; - } - // Check if NDArray was changed. - if (op.weights_version != weights_version) { - op.has_cache = false; - op.weights_version = weights_version; + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(param, ctx); + } else { + state = OpStatePtr::Create>(param, ctx); } - - DType* workptr = static_cast(op.mem_space_.data().dptr_); - mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; - mkldnn::memory::dims src_layer_tz = {T, N, D * H}; - mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; - auto dst_layer_md = mkldnn::memory::desc( - { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - if (op.x_memory.size() == 0) { - if (D == 1 && I == H) { - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_n); - - mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - DType* weight_layer_n = workptr; // L * I * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - workptr = wh_n + L * ngates * H * H; - - mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n1 = mkldnn::memory::desc( - { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - for (int l = 0; l < L; l++) { - DType* src_iter_n1 = workptr; // nstates * N * H - auto src_iter_memory_n1 = - mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); - op.concat_iter_memory.push_back(src_iter_memory_n1); - workptr = src_iter_n1 + nstates * N * H; - } - mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n = mkldnn::memory::desc( - { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_n = workptr; // L * nstates * N * H - auto src_iter_memory_n = - mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); - op.concat_iter_memory.push_back(src_iter_memory_n); - op.hcx_memory.push_back(src_iter_memory_n); - DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - workptr = dst_iter_n + L * nstates * N * H; - - } else { - auto user_src_layer_md_0 = mkldnn::memory::desc( - { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_0); - - mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; - auto user_weight_layer_md_0 = mkldnn::memory::desc( - { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md_0 = mkldnn::memory::desc( - { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - DType* weight_layer_0 = workptr; // D * I * ngates * H - auto user_weight_layer_memory_0 - = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); - op.wx_memory.push_back(user_weight_layer_memory_0); - - DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H - auto user_weight_iter_memory_0 - = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); - op.wh_memory.push_back(user_weight_iter_memory_0); - - DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H - auto user_bias_memory_0 = - mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); - op.bias_memory.push_back(user_bias_memory_0); - workptr = bias_0 + D * ngates * H; - - auto wx_md_0 = mkldnn::memory::desc( - { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wx_memory_0 = - mkldnn::memory({ wx_md_0, cpu_engine }); - auto wh_md_0 = mkldnn::memory::desc( - { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_0 = - mkldnn::memory({ wh_md_0, cpu_engine }); - if (D == 2) { - DType* wx_0 = workptr; // D * ngates * I * H - wx_memory_0.set_data_handle(wx_0); - DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H - wh_memory_0.set_data_handle(wh_0); - workptr = wh_0 + D * ngates * H * H; - } - op.concat_weight_memory.push_back(wx_memory_0); - op.concat_weight_memory.push_back(wh_memory_0); - - mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_undi_md_0 = mkldnn::memory::desc( - { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_undi_0 = workptr; // nstates * N * H - auto src_iter_undi_memory_0 = - mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); - op.concat_iter_memory.push_back(src_iter_undi_memory_0); - workptr = src_iter_undi_0 + nstates * N * H; - if (D == 1) { - op.hcx_memory.push_back(src_iter_undi_memory_0); - } else { - DType* src_iter_undi2_0 = workptr; // nstates * N * H - auto src_iter_undi2_memory_0 = - mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); - op.concat_iter_memory.push_back(src_iter_undi2_memory_0); - - mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc - auto src_iter_md_0 = mkldnn::memory::desc( - { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H - auto src_iter_memory_0 = - mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); - op.concat_iter_memory.push_back(src_iter_memory_0); - op.hcx_memory.push_back(src_iter_memory_0); - workptr = src_iter_0 + D * nstates * N * H; - } - - DType* dst_layer_0 = workptr; // T * N * D * H - auto dst_layer_memory_0 - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); - op.y_memory.push_back(dst_layer_memory_0); - - mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc - auto dst_iter_md_0 = mkldnn::memory::desc( - { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H - auto dst_iter_memory_0 = - mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0); - op.hcy_memory.push_back(dst_iter_memory_0); - workptr = dst_iter_0 + D * nstates * N * H; - - // next L - 1 layers - if (L > 1 && D == 1) { - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory); - - mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + - (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * ngates * H * H - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - workptr = wh_n + (L - 1) * ngates * H * H; - - mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n1 = mkldnn::memory::desc( - { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - for (int l = 0; l < L - 1; l++) { - DType* src_iter_n1 = workptr; // nstates * N * H - auto src_iter_memory_n1 = - mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); - op.concat_iter_memory.push_back(src_iter_memory_n1); - workptr = src_iter_n1 + nstates * N * H; - } - mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n = mkldnn::memory::desc( - { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_n = workptr; // (L - 1) * nstates * N * H - auto src_iter_memory_n = - mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); - op.concat_iter_memory.push_back(src_iter_memory_n); - op.hcx_memory.push_back(src_iter_memory_n); - - DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - } - - if (L > 1 && D == 2) { - mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {1, D, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - - for (int l = 0; l < L; l++) { - DType* weight_layer_n = workptr; // D * (H * D) * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + - D * (H * D) * ngates * H; // D * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - workptr = bias_n + D * ngates * H; - } - - DType* wx_n = workptr; // D * ngates * (D * H) * H - DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - - mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_undi_md = mkldnn::memory::desc( - { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H - auto src_iter_undi_memory = - mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi); - op.concat_iter_memory.push_back(src_iter_undi_memory_0); - - DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H - auto src_iter_undi2_memory = - mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2); - op.concat_iter_memory.push_back(src_iter_undi2_memory); - - mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc - auto src_iter_md = mkldnn::memory::desc( - { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H - auto src_iter_memory = - mkldnn::memory({ src_iter_md, cpu_engine }, src_iter); - op.concat_iter_memory.push_back(src_iter_memory); - op.hcx_memory.push_back(src_iter_memory); - - DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - } - } - } - op.Forward(ctx, in_blobs, req, out_blobs); }); }); + return state; } -static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx, +#if MXNET_USE_MKLDNN == 1 +static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - if (SupportMKLDNNRNN(inputs[0])) { - RNNStatefulComputeCPU(state_ptr, ctx, inputs, req, outputs); - return; + if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kFloat16) && + inputs[0].shape().ndim() == 3) { + MKLDNNRnnOp& op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); + } else { + FallBackCompute(RNNStatefulCompute, state_ptr, ctx, inputs, req, outputs); } - int use_mkldnn_rnn = dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1); - dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", 0); - FallBackCompute(RNNStatefulCompute, state_ptr, ctx, inputs, req, outputs); - dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", use_mkldnn_rnn); } -#endif + +static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == mshadow::kFloat16) && + inputs[0].shape().ndim() == 3) { + MKLDNNRnnOp& op = state_ptr.get_state(); + op.Backward(ctx, inputs, req, outputs); + } else { + FallBackCompute(RNNStatefulGradCompute, state_ptr, ctx, inputs, req, outputs); + } +} +#endif // MXNET_USE_MKLDNN == 1 NNVM_REGISTER_OP(RNN) .add_alias("_npx_rnn") @@ -726,6 +370,16 @@ The definition of GRU here is slightly different from paper but compatible with const RNNParam& params = nnvm::get(attrs.parsed); return ListArguments(params); }) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + const RNNParam& params = nnvm::get(attrs.parsed); + std::vector names{"output"}; + if (params.state_outputs) { + names.emplace_back("state_output"); + if (params.mode == rnn_enum::kLstm) + names.emplace_back("statecell_output"); + } + return names; +}) .set_attr("FInferShape", RNNShape) .set_attr("FInferType", RNNType) .set_attr("FInferStorageType", RNNStorageType) @@ -756,7 +410,12 @@ NNVM_REGISTER_OP(_backward_RNN) .set_attr_parser(ParamParser) .set_attr("TIsLayerOpBackward", true) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", RNNStorageType) .set_attr("FStatefulCompute", RNNStatefulGradCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FStatefulComputeEx", RNNStatefulGradComputeExCPU) +#endif .set_attr("FResourceRequestEx", RNNResourceEx); } // namespace op } // namespace mxnet diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h index 425ea4a3c6ab..e1b4a2b79c0a 100644 --- a/src/operator/rnn_impl.h +++ b/src/operator/rnn_impl.h @@ -44,13 +44,6 @@ namespace mxnet { namespace op { -namespace rnn_enum { - enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength}; - enum RNNOpOutputs {kOut, kStateOut, kStateCellOut}; - enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru}; - enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace}; -} - template inline DType sigmoid(DType x) { return 1.0f / (1.0f + exp(-x)); diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 548225f0496b..ec2670974f49 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -26,6 +26,7 @@ #include "./softmax_output-inl.h" #if MXNET_USE_MKLDNN == 1 #include "./nn/mkldnn/mkldnn_ops-inl.h" +#include "./nn/mkldnn/mkldnn_base-inl.h" #endif namespace mxnet { namespace op { diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h index 002b012bc35e..509d25037ad7 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h @@ -47,10 +47,10 @@ static inline bool IsOutputUInt8(const MKLDNNConvFusionParam& param) { param.alg == mkldnn::algorithm::eltwise_bounded_relu); }; if ((!mkldnn_param.with_sum) && mkldnn_param.with_act) { - CHECK(param.full_conv_param.act_param.alg != mkldnn::algorithm::algorithm_undef); + CHECK(param.full_conv_param.act_param.alg != mkldnn::algorithm::undef); result = IsOutputUInt8Helper(param.full_conv_param.act_param); } else if (mkldnn_param.with_postsum_act) { - CHECK(param.full_conv_param.postsum_act_param.alg != mkldnn::algorithm::algorithm_undef); + CHECK(param.full_conv_param.postsum_act_param.alg != mkldnn::algorithm::undef); result = IsOutputUInt8Helper(param.full_conv_param.postsum_act_param); } return result; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 886a21b44fc9..e1f9174898c7 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -128,40 +128,37 @@ static void ConvertWeightBias2MKLDNN(const MKLDNNConvFullParam ¶m, NDArray *weight, NDArray *bias, bool has_bias, float data_scale, const std::vector &weight_scales) { MKLDNNStream *stream = MKLDNNStream::Get(); - const auto new_weight = NDArray(fwd_pd.weights_primitive_desc()); + const auto new_weight = NDArray(fwd_pd.weights_desc()); const auto conv_weights_memory = new_weight.GetMKLDNNData(); - primitive_attr weight_attr; + mkldnn::primitive_attr weight_attr; if (weight_scales.size()) { const int weight_mask = (weight_scales.size()) == 1 ? 0 : 1; - weight_attr.set_int_output_round_mode(round_mode::round_nearest); weight_attr.set_output_scales(weight_mask, weight_scales); } auto default_weights_memory = GetWeights(*weight, param.conv_param.num_group); if (default_weights_memory == nullptr) default_weights_memory = weight->GetMKLDNNData(); const auto weight_reorder_pd = - mkldnn::reorder::primitive_desc(default_weights_memory->get_primitive_desc(), - conv_weights_memory->get_primitive_desc(), weight_attr); - stream->RegisterPrim( - mkldnn::reorder(weight_reorder_pd, *default_weights_memory, *conv_weights_memory)); - + mkldnn::reorder::primitive_desc(*default_weights_memory, *conv_weights_memory, weight_attr); + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::reorder(weight_reorder_pd), + {{MKLDNN_ARG_FROM, *default_weights_memory}, {MKLDNN_ARG_TO, *conv_weights_memory}}); NDArray new_bias; if (has_bias && data_scale) { std::vector bias_scales(weight_scales.size()); for (size_t c = 0; c < weight_scales.size(); ++c) { bias_scales[c] = weight_scales[c] * data_scale; } - new_bias = NDArray(fwd_pd.bias_primitive_desc()); + new_bias = NDArray(fwd_pd.bias_desc()); const auto conv_bias_memory = new_bias.GetMKLDNNData(); const int bias_mask = (bias_scales.size()) == 1 ? 0 : 1; - primitive_attr bias_attr; - bias_attr.set_int_output_round_mode(round_mode::round_nearest); + mkldnn::primitive_attr bias_attr; bias_attr.set_output_scales(bias_mask, bias_scales); auto bias_weights_memory = bias->GetMKLDNNData(); - auto bias_reorder_pd = - mkldnn::reorder::primitive_desc(bias_weights_memory->get_primitive_desc(), - conv_bias_memory->get_primitive_desc(), bias_attr); - stream->RegisterPrim( - mkldnn::reorder(bias_reorder_pd, *bias_weights_memory, *conv_bias_memory)); + const auto bias_reorder_pd = + mkldnn::reorder::primitive_desc(*bias_weights_memory, *conv_bias_memory, bias_attr); + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::reorder(bias_reorder_pd), + {{MKLDNN_ARG_FROM, *bias_weights_memory}, {MKLDNN_ARG_TO, *conv_bias_memory}}); } stream->Submit(); *weight = new_weight; @@ -186,6 +183,7 @@ class SgMKLDNNConvOperator { nnvm::Symbol subgraph_sym_; MKLDNNConvFusionParam param_; std::shared_ptr fwd_; + mkldnn_args_map_t args_; NDArray cached_weight_; NDArray cached_bias_; float cached_data_min_; @@ -253,22 +251,24 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); auto out_mkl_mem = outputs[kOut].GetMKLDNNData(); if (outputs[kOut].dtype() == mshadow::kInt32) { - auto mem_desc = in_mkl_mem->get_primitive_desc().desc(); - auto this_dtype = get_mkldnn_type(mshadow::kInt32); - mkldnn::memory::desc omd( - mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims), - this_dtype, static_cast(mem_desc.data.format)); - mkldnn::memory::primitive_desc opd(omd, CpuEngine::Get()->get_engine()); - mkldnn_mem_ptr tmp_mem(new mkldnn::memory(opd, out_mkl_mem->get_data_handle())); + const auto& mem_desc = in_mkl_mem->get_desc(); + const auto this_dtype = get_mkldnn_type(mshadow::kInt32); + auto omd = mem_desc; + omd.data.data_type = static_cast(this_dtype); + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(omd, CpuEngine::Get()->get_engine(), + out_mkl_mem->get_data_handle())); MKLDNNStream::Get()->RegisterMem(tmp_mem); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*in_mkl_mem, *tmp_mem)); + MKLDNNStream::Get()->RegisterPrimArgs( + mkldnn::reorder(*in_mkl_mem, *tmp_mem), + {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}}); output = NDArray(tmp_mem); } else { - mkldnn_mem_ptr tmp_mem( - new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle())); - MKLDNNStream::Get()->RegisterMem(tmp_mem); - mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get()); - output = NDArray(tmp_mem); + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(in_mkl_mem->get_desc(), + CpuEngine::Get()->get_engine(), + out_mkl_mem->get_data_handle())); + MKLDNNStream::Get()->RegisterMem(tmp_mem); + mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get()); + output = NDArray(tmp_mem); } } } @@ -391,27 +391,25 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, fwd_.reset(new MKLDNNConvForward( full_conv_param, ctx.is_train, data, cached_weight_, has_bias ? &cached_bias_ : nullptr, output)); - ConvertWeightBias2MKLDNN(full_conv_param, fwd_->fwd_pd, &cached_weight_, &cached_bias_, + ConvertWeightBias2MKLDNN(full_conv_param, fwd_->GetPd(), &cached_weight_, &cached_bias_, has_bias, data_scale_, weight_scales_); - fwd_->SetNewMem(*data.GetMKLDNNData(), *cached_weight_.GetMKLDNNData(), - has_bias ? cached_bias_.GetMKLDNNData() : nullptr, - *output.GetMKLDNNData()); + args_[MKLDNN_ARG_SRC] = *data.GetMKLDNNData(); + args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData(); + if (has_bias) args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData(); + args_[MKLDNN_ARG_DST] = *output.GetMKLDNNData(); initialized_ = true; } if (mkldnn_param.with_sum) { - const auto output_mem = output.GetMKLDNNData(); - const auto out_mem_desc = output_mem->get_primitive_desc().desc(); - const auto dst_format = fwd_->fwd_pd.dst_primitive_desc().desc().data.format; - if (out_mem_desc.data.format != dst_format) { - auto tmp_out_mem = output.GetMKLDNNDataReorder(fwd_->fwd_pd.dst_primitive_desc()); - mkldnn::memory::desc data_md( - mkldnn::memory::dims(out_mem_desc.data.dims, - out_mem_desc.data.dims + out_mem_desc.data.ndims), - static_cast(out_mem_desc.data.data_type), - static_cast(dst_format)); - mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); - mkldnn_mem_ptr new_out_mem(new mkldnn::memory(pd, output_mem->get_data_handle())); + const auto& output_mem = output.GetMKLDNNData(); + const auto& out_mem_desc = output_mem->get_desc(); + const auto& dst_mem_desc = fwd_->GetPd().dst_desc(); + if (out_mem_desc != dst_mem_desc) { + auto tmp_out_mem = output.GetMKLDNNDataReorder(fwd_->GetPd().dst_desc()); + auto data_md = dst_mem_desc; + data_md.data.data_type = static_cast(out_mem_desc.data.data_type); + mkldnn_mem_ptr new_out_mem(new mkldnn::memory(data_md, CpuEngine::Get()->get_engine(), + output_mem->get_data_handle())); MKLDNNStream::Get()->RegisterMem(new_out_mem); mxnet::MKLDNNCopy(*tmp_out_mem, new_out_mem.get()); output = NDArray(new_out_mem); @@ -419,10 +417,11 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, } if (mkldnn_param.quantized) { - auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_primitive_desc()); - mkldnn::memory *mem = output.CreateMKLDNNData(fwd_->fwd_pd.dst_primitive_desc()); - fwd_->SetNewMem(*data_mem, *mem); - MKLDNNStream::Get()->RegisterPrim(fwd_->GetFwd()); + auto data_mem = data.GetMKLDNNDataReorder(fwd_->GetPd().src_desc()); + mkldnn::memory *mem = output.CreateMKLDNNData(fwd_->GetPd().dst_desc()); + args_[MKLDNN_ARG_SRC] = *data_mem; + args_[MKLDNN_ARG_DST] = *mem; + MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_); MKLDNNStream::Get()->Submit(); } else { std::vector new_inputs; @@ -441,9 +440,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, } if (mkldnn_param.with_sum) { auto out = const_cast(outputs[kOut]); - auto format = static_cast( - fwd_->fwd_pd.dst_primitive_desc().desc().data.format); - out.UpdateMKLDNNMemDesc(format); + out.UpdateMKLDNNMemDesc(fwd_->GetPd().dst_desc()); } } diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h index cb355dab9abe..9a09d91ae5d0 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h @@ -55,7 +55,7 @@ static inline mkldnn::algorithm GetMKLDNNEltwiseAlgo(const std::string op_name) else LOG(FATAL) << "Unsupported eltwise fusion op: " << op_name; - return mkldnn::algorithm::algorithm_undef; + return mkldnn::algorithm::undef; } static inline bool IsOutputUint8(const MKLDNNFCFullParam& full_param) { diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc index d0d2b51918b1..269017ea6a03 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc @@ -34,25 +34,33 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN) MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); +#endif // MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty); - - +#endif // MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE) .set_attr("context", Context::CPU()); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty) .set_attr("quantize", true); +#endif // MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 + MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty) .set_attr("quantize", true); - +#endif // MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty); +#endif // MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty); - +#endif // MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 1 } // namespace op } // namespace mxnet - #endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 7c35c44305a8..2a834bb9dc55 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -217,7 +217,7 @@ static void CopyEx(const nnvm::NodeAttrs& attrs, FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, inputs, req, outputs); return; } -#endif +#endif // MXNET_USE_MKLDNN == 1 UnaryOp::IdentityComputeEx(attrs, ctx, inputs, req, outputs); } @@ -238,7 +238,7 @@ static inline bool CopyStorageType(const nnvm::NodeAttrs& attrs, && out_attrs->at(0) == kDefaultStorage) { *dispatch_mode = DispatchMode::kFComputeEx; } -#endif +#endif // MXNET_USE_MKLDNN == 1 return ret; } @@ -253,7 +253,7 @@ MXNET_OPERATOR_REGISTER_UNARY(_copy) return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("TIsMKLDNN", true) -#endif +#endif // MXNET_USE_MKLDNN == 1 .set_attr("FInplaceIdentity", [](const NodeAttrs& attrs){ return std::vector{true}; @@ -275,7 +275,7 @@ NNVM_REGISTER_OP(_backward_copy) .set_attr("TIsMKLDNN", true) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; -}) +}) // MXNET_USE_MKLDNN == 1 #endif .set_attr("FInplaceIdentity", [](const NodeAttrs& attrs){ diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 7c3005e583b6..801e4e7126b4 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -459,6 +459,10 @@ struct ExpandDimParam : public dmlc::Parameter { "the input `NDArray`'s dimension is `ndim`, the range of " "the inserted axis is `[-ndim, ndim]`"); } + + bool operator==(const ExpandDimParam &other) const { + return this->axis == other.axis; + } }; @@ -3040,6 +3044,16 @@ struct hash { return ret; } }; + +template<> +struct hash { + size_t operator()(const mxnet::op::ExpandDimParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.axis); + return ret; + } +}; + } // namespace std #endif // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_ diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 99fba15d47ba..0f63061d7c09 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -25,9 +25,11 @@ // this will be invoked by gcc and compile CPU version #include "./matrix_op-inl.h" #include "./elemwise_unary_op.h" +#if MXNET_USE_MKLDNN == 1 #include "../nn/mkldnn/mkldnn_ops-inl.h" #include "../nn/mkldnn/mkldnn_base-inl.h" #include "../nn/mkldnn/mkldnn_slice-inl.h" +#endif namespace mxnet { namespace op { @@ -114,7 +116,7 @@ static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); // If inputs are supposed to be in MKLDNN format and - // MKLDNNsupport the data type or the shape. Then convert + // MKLDNN support the data type or the shape. Then convert // it to the output format and shape MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); } @@ -134,66 +136,42 @@ inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs, NNVM_REGISTER_OP(Reshape) .add_alias("reshape") .describe(R"code(Reshapes the input array. - .. note:: ``Reshape`` is deprecated, use ``reshape`` - Given an array and a shape, this function returns a copy of the array in the new shape. The shape is a tuple of integers such as (2,3,4). The size of the new shape should be same as the size of the input array. - Example:: - reshape([1,2,3,4], shape=(2,2)) = [[1,2], [3,4]] - Some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}. The significance of each is explained below: - - ``0`` copy this dimension from the input to the output shape. - Example:: - - input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2) - input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4) - - ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions keeping the size of the new array same as that of the input array. At most one dimension of shape can be -1. - Example:: - - input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4) - input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8) - input shape = (2,3,4), shape=(-1,), output shape = (24,) - - ``-2`` copy all/remainder of the input dimensions to the output shape. - Example:: - - input shape = (2,3,4), shape = (-2,), output shape = (2,3,4) - input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4) - input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1) - - ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension. - Example:: - - input shape = (2,3,4), shape = (-3,4), output shape = (6,4) - input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20) - input shape = (2,3,4), shape = (0,-3), output shape = (2,12) - input shape = (2,3,4), shape = (-3,-2), output shape = (6,4) - - ``-4`` split one dimension of the input into two dimensions passed subsequent to -4 in shape (can contain -1). - Example:: - - input shape = (2,3,4), shape = (-4,1,2,-2), output shape =(1,2,3,4) - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4) - If the argument `reverse` is set to 1, then the special values are inferred from right to left. - Example:: - - without reverse=1, for input shape = (10,5,4), shape = (-1,0), output shape would be (40,5) - with reverse=1, output shape will be (50,4). - )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -230,7 +208,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); // If inputs are supposed to be in MKLDNN format and - // MKLDNNsupport the data type or the shape. Then convert + // MKLDNN support the data type or the shape. Then convert // it to the output format and shape MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); } @@ -251,17 +229,12 @@ NNVM_REGISTER_OP(Flatten) .add_alias("flatten") .add_alias("_npx_batch_flatten") .describe(R"code(Flattens the input array into a 2-D array by collapsing the higher dimensions. - .. note:: `Flatten` is deprecated. Use `flatten` instead. - For an input array with shape ``(d1, d2, ..., dk)``, `flatten` operation reshapes the input array into an output array of shape ``(d1, d2*...*dk)``. - Note that the behavior of this function is different from numpy.ndarray.flatten, which behaves similar to mxnet.ndarray.reshape((-1,)). - Example:: - x = [[ [1,2,3], [4,5,6], @@ -271,10 +244,8 @@ Example:: [4,5,6], [7,8,9] ]], - flatten(x) = [[ 1., 2., 3., 4., 5., 6., 7., 8., 9.], [ 1., 2., 3., 4., 5., 6., 7., 8., 9.]] - )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -334,30 +305,21 @@ inline static bool TransposeStorageType(const nnvm::NodeAttrs& attrs, NNVM_REGISTER_OP(transpose) .describe(R"code(Permutes the dimensions of an array. - Examples:: - x = [[ 1, 2], [ 3, 4]] - transpose(x) = [[ 1., 3.], [ 2., 4.]] - x = [[[ 1., 2.], [ 3., 4.]], - [[ 5., 6.], [ 7., 8.]]] - transpose(x) = [[[ 1., 5.], [ 3., 7.]], - [[ 2., 6.], [ 4., 8.]]] - transpose(x, axes=(1,0,2)) = [[[ 1., 2.], [ 5., 6.]], - [[ 3., 4.], [ 7., 8.]]] )code" ADD_FILELINE) @@ -395,13 +357,36 @@ Examples:: .add_arguments(TransposeParam::__FIELDS__()); +#if MXNET_USE_MKLDNN == 1 +static void ExpandDimEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + // If inputs are supposed to be in MKLDNN format and + // MKLDNN support the data type or the shape. Then convert + // it to the output format and shape + MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); +} + +inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); +} +#endif + NNVM_REGISTER_OP(expand_dims) .add_alias("_npi_expand_dims") .describe(R"code(Inserts a new axis of size 1 into the array shape - For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1)`` will return a new array with shape ``(2,1,3,4)``. - )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -418,6 +403,14 @@ will return a new array with shape ``(2,1,3,4)``. }) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_reshape"}) .set_attr("FCompute", UnaryOp::IdentityCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", ExpandDimEx) +.set_attr("FInferStorageType", ExpandDimStorageType) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif .add_argument("data", "NDArray-or-Symbol", "Source input") .add_arguments(ExpandDimParam::__FIELDS__()); @@ -449,44 +442,33 @@ NNVM_REGISTER_OP(slice) MXNET_ADD_SPARSE_OP_ALIAS(slice) .add_alias("crop") .describe(R"code(Slices a region of the array. - .. note:: ``crop`` is deprecated. Use ``slice`` instead. - This function returns a sliced array between the indices given by `begin` and `end` with the corresponding `step`. - For an input array of ``shape=(d_0, d_1, ..., d_n-1)``, slice operation with ``begin=(b_0, b_1...b_m-1)``, ``end=(e_0, e_1, ..., e_m-1)``, and ``step=(s_0, s_1, ..., s_m-1)``, where m <= n, results in an array with the shape ``(|e_0-b_0|/|s_0|, ..., |e_m-1-b_m-1|/|s_m-1|, d_m, ..., d_n-1)``. - The resulting array's *k*-th dimension contains elements from the *k*-th dimension of the input array starting from index ``b_k`` (inclusive) with step ``s_k`` until reaching ``e_k`` (exclusive). - If the *k*-th elements are `None` in the sequence of `begin`, `end`, and `step`, the following rule will be used to set default values. If `s_k` is `None`, set `s_k=1`. If `s_k > 0`, set `b_k=0`, `e_k=d_k`; else, set `b_k=d_k-1`, `e_k=-1`. - The storage type of ``slice`` output depends on storage types of inputs - - slice(csr) = csr - otherwise, ``slice`` generates output with default storage - .. note:: When input data storage type is csr, it only supports step=(), or step=(None,), or step=(1,) to generate a csr output. For other step parameter values, it falls back to slicing a dense tensor. - Example:: - x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]] - slice(x, begin=(0,1), end=(2,4)) = [[ 2., 3., 4.], [ 6., 7., 8.]] slice(x, begin=(None, 0), end=(None, 3), step=(-1, 2)) = [[9., 11.], @@ -566,23 +548,17 @@ NNVM_REGISTER_OP(_slice_assign_scalar) NNVM_REGISTER_OP(slice_axis) .describe(R"code(Slices along a given axis. - Returns an array slice along a given `axis` starting from the `begin` index to the `end` index. - Examples:: - x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]] - slice_axis(x, axis=0, begin=1, end=3) = [[ 5., 6., 7., 8.], [ 9., 10., 11., 12.]] - slice_axis(x, axis=1, begin=0, end=2) = [[ 1., 2.], [ 5., 6.], [ 9., 10.]] - slice_axis(x, axis=1, begin=-3, end=-1) = [[ 2., 3.], [ 6., 7.], [ 10., 11.]] @@ -606,46 +582,31 @@ NNVM_REGISTER_OP(_backward_slice_axis) NNVM_REGISTER_OP(slice_like) .describe(R"code(Slices a region of the array like the shape of another array. - This function is similar to ``slice``, however, the `begin` are always `0`s and `end` of specific axes are inferred from the second input `shape_like`. - Given the second `shape_like` input of ``shape=(d_0, d_1, ..., d_n-1)``, a ``slice_like`` operator with default empty `axes`, it performs the following operation: - `` out = slice(input, begin=(0, 0, ..., 0), end=(d_0, d_1, ..., d_n-1))``. - When `axes` is not empty, it is used to speficy which axes are being sliced. - Given a 4-d input data, ``slice_like`` operator with ``axes=(0, 2, -1)`` will perform the following operation: - `` out = slice(input, begin=(0, 0, 0, 0), end=(d_0, None, d_2, d_3))``. - Note that it is allowed to have first and second input with different dimensions, however, you have to make sure the `axes` are specified and not exceeding the dimension limits. - For example, given `input_1` with ``shape=(2,3,4,5)`` and `input_2` with ``shape=(1,2,3)``, it is not allowed to use: - `` out = slice_like(a, b)`` because ndim of `input_1` is 4, and ndim of `input_2` is 3. - The following is allowed in this situation: - `` out = slice_like(a, b, axes=(0, 2))`` - Example:: - x = [[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.]] - y = [[ 0., 0., 0.], [ 0., 0., 0.]] - slice_like(x, y) = [[ 1., 2., 3.] [ 5., 6., 7.]] slice_like(x, y, axes=(0, 1)) = [[ 1., 2., 3.] @@ -691,23 +652,15 @@ NNVM_REGISTER_OP(clip) MXNET_ADD_SPARSE_OP_ALIAS(clip) .add_alias("_npi_clip") .describe(R"code(Clips (limits) the values in an array. - Given an interval, values outside the interval are clipped to the interval edges. Clipping ``x`` between `a_min` and `a_max` would be:: - .. math:: - clip(x, a_min, a_max) = \max(\min(x, a_max), a_min)) - Example:: - x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - clip(x,1,8) = [ 1., 1., 2., 3., 4., 5., 6., 7., 8., 8.] - The storage type of ``clip`` output depends on storage types of inputs and the a_min, a_max \ parameter values: - - clip(default) = default - clip(row_sparse, a_min <= 0, a_max >= 0) = row_sparse - clip(csr, a_min <= 0, a_max >= 0) = csr @@ -715,7 +668,6 @@ parameter values: - clip(row_sparse, a_min > 0, a_max > 0) = default - clip(csr, a_min < 0, a_max < 0) = csr - clip(csr, a_min > 0, a_max > 0) = csr - )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -769,28 +721,20 @@ NNVM_REGISTER_OP(_backward_clip) NNVM_REGISTER_OP(repeat) .add_alias("_np_repeat") .describe(R"code(Repeats elements of an array. - By default, ``repeat`` flattens the input array into 1-D and then repeats the elements:: - x = [[ 1, 2], [ 3, 4]] - repeat(x, repeats=2) = [ 1., 1., 2., 2., 3., 3., 4., 4.] - The parameter ``axis`` specifies the axis along which to perform repeat:: - repeat(x, repeats=2, axis=1) = [[ 1., 1., 2., 2.], [ 3., 3., 4., 4.]] - repeat(x, repeats=2, axis=0) = [[ 1., 2.], [ 1., 2.], [ 3., 4.], [ 3., 4.]] - repeat(x, repeats=2, axis=-1) = [[ 1., 1., 2., 2.], [ 3., 3., 4., 4.]] - )code" ADD_FILELINE) .set_num_outputs(1) .set_num_inputs(1) @@ -820,35 +764,25 @@ NNVM_REGISTER_OP(_backward_repeat) NNVM_REGISTER_OP(tile) .add_alias("_npi_tile") .describe(R"code(Repeats the whole array multiple times. - If ``reps`` has length *d*, and input array has dimension of *n*. There are three cases: - - **n=d**. Repeat *i*-th dimension of the input by ``reps[i]`` times:: - x = [[1, 2], [3, 4]] - tile(x, reps=(2,3)) = [[ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.], [ 1., 2., 1., 2., 1., 2.], [ 3., 4., 3., 4., 3., 4.]] - - **n>d**. ``reps`` is promoted to length *n* by pre-pending 1's to it. Thus for an input shape ``(2,3)``, ``repos=(2,)`` is treated as ``(1,2)``:: - - tile(x, reps=(2,)) = [[ 1., 2., 1., 2.], [ 3., 4., 3., 4.]] - - **n) .set_num_inputs(1) diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h index 87df39a2754d..286496108128 100644 --- a/tests/cpp/include/test_core_op.h +++ b/tests/cpp/include/test_core_op.h @@ -79,7 +79,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer keys.emplace_back(i_iter->first.c_str()); values.emplace_back(i_iter->second.c_str()); } - return imperative::ParseAttrs(op, op->num_inputs, count, &keys[0], &values[0]); + return imperative::ParseAttrs(op, op->num_inputs, count, keys.data(), values.data()); } /*! diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h index f1682772a14a..1466c992fca9 100644 --- a/tests/cpp/include/test_mkldnn.h +++ b/tests/cpp/include/test_mkldnn.h @@ -37,29 +37,29 @@ using namespace mxnet; -inline static mkldnn::memory::primitive_desc GetMemPD(const mxnet::TShape s, int dtype, - mkldnn::memory::format format) { +inline static mkldnn::memory::desc GetMemDesc(const mxnet::TShape s, const int dtype, + const mkldnn::memory::format_tag format_tag) { mkldnn::memory::dims dims(s.ndim()); for (size_t i = 0; i < dims.size(); i++) dims[i] = s[i]; - mkldnn::memory::desc desc{dims, get_mkldnn_type(dtype), format}; - return mkldnn::memory::primitive_desc(desc, CpuEngine::Get()->get_engine()); + mkldnn::memory::desc desc{dims, get_mkldnn_type(dtype), format_tag}; + return desc; } -inline static mkldnn::memory::primitive_desc GetExpandedMemPD( - mkldnn::memory::primitive_desc pd, float scale, int dim = 0) { - CHECK(dim < pd.desc().data.ndims) << "dimension cannot be larger than total dimensions of input"; - mxnet::TShape s(pd.desc().data.ndims, -1); - for (size_t i = 0; i < pd.desc().data.ndims; i++) - s[i] = pd.desc().data.dims[i]; - s[dim] = static_cast(s[dim] * scale); - return GetMemPD(s, mshadow::DataType::kFlag, - static_cast(pd.desc().data.format)); +inline static mkldnn::memory::desc GetExpandedMemDesc( + mkldnn::memory::desc md, const float scale, const int dim = 0) { + CHECK(dim < md.data.ndims) << "dimension cannot be larger than total dimensions of input"; + mxnet::TShape s(md.data.ndims, -1); + for (size_t i = 0; i < md.data.ndims; i++) + s[i] = md.data.dims[i]; + s[dim] = static_cast(s[dim] * scale); + return GetMemDesc(s, mshadow::DataType::kFlag, + static_cast(GetDefaultFormat(md))); } struct TestArrayShapes { std::vector shapes; - std::vector pds; + std::vector mds; }; // Init arrays with the default layout. @@ -78,17 +78,17 @@ inline static void InitDefaultArray(NDArray *arr, bool is_rand = false) { // Init arrays with the specified layout. -inline static void InitMKLDNNArray(NDArray *arr, const mkldnn::memory::primitive_desc &pd, - bool is_rand = false) { +inline static void InitMKLDNNArray(NDArray *arr, const mkldnn::memory::desc &desc, + bool is_rand = false) { InitDefaultArray(arr, is_rand); - arr->MKLDNNDataReorderAsync(pd); + arr->MKLDNNDataReorderAsync(desc); arr->WaitToRead(); } -inline static bool IsSameShape(mkldnn::memory::primitive_desc pd, mxnet::TShape shape) { - if (pd.desc().data.ndims != shape.ndim()) return false; +inline static bool IsSameShape(const mkldnn::memory::desc &desc, const mxnet::TShape &shape) { + if (desc.data.ndims != shape.ndim()) return false; for (size_t i = 0; i < shape.ndim(); i++) - if (pd.desc().data.dims[i] != shape[i]) return false; + if (desc.data.dims[i] != shape[i]) return false; return true; } @@ -97,81 +97,81 @@ inline static bool IsSameShape(mkldnn::memory::primitive_desc pd, mxnet::TShape // it's specific for certain array shapes. It covers at least one special format // for each of the formats: nchw, oihw, goihw. // To test the logic of the code in NDArray, these formats should be enough. -inline static std::vector GetMKLDNNFormat(size_t num_dims, int dtype) { +inline static std::vector GetMKLDNNFormat(size_t num_dims, int dtype) { if (num_dims == 4) { mkldnn::memory::dims data_dims{1, 3, 224, 224}; mkldnn::memory::desc data_md{data_dims, get_mkldnn_type(dtype), - mkldnn::memory::format::any}; + mkldnn::memory::format_tag::any}; mkldnn::memory::dims weight_dims{96, 3, 11, 11}; mkldnn::memory::desc weight_md{weight_dims, get_mkldnn_type(dtype), - mkldnn::memory::format::any}; + mkldnn::memory::format_tag::any}; mkldnn::memory::dims output_dims{1, 96, 54, 54}; mkldnn::memory::desc out_md{output_dims, get_mkldnn_type(dtype), - mkldnn::memory::format::any}; + mkldnn::memory::format_tag::any}; mkldnn::memory::dims strides{4, 4}; mkldnn::memory::dims padding{0, 0}; mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training, mkldnn::algorithm::convolution_direct, data_md, weight_md, out_md, strides, - padding, padding, mkldnn::padding_kind::zero); + padding, padding); mkldnn::convolution_forward::primitive_desc pd(desc, CpuEngine::Get()->get_engine()); - while (pd.dst_primitive_desc().get_size() != GetMemDescSize(out_md) || - pd.src_primitive_desc().get_size() != GetMemDescSize(data_md) || - pd.weights_primitive_desc().get_size() != GetMemDescSize(weight_md)) { + while (pd.dst_desc().get_size() != GetMemDescSize(out_md) || + pd.src_desc().get_size() != GetMemDescSize(data_md) || + pd.weights_desc().get_size() != GetMemDescSize(weight_md)) { CHECK(pd.next_impl()) << "No implementation"; } - std::vector ret(1); - ret[0] = static_cast(pd.dst_primitive_desc().desc().data.format); - printf("format: %d \n", ret[0]); + std::vector ret(1); + ret[0] = static_cast(GetDefaultFormat(pd.dst_desc())); + printf("format: %d \n", static_cast(ret[0])); return ret; } else if (num_dims == 5) { mkldnn::memory::dims data_dims{1, 32, 112, 112}; mkldnn::memory::desc data_md{data_dims, get_mkldnn_type(dtype), - mkldnn::memory::format::any}; + mkldnn::memory::format_tag::any}; mkldnn::memory::dims weight_dims{32, 1, 1, 3, 3}; mkldnn::memory::desc weight_md{weight_dims, get_mkldnn_type(dtype), - mkldnn::memory::format::any}; + mkldnn::memory::format_tag::any}; mkldnn::memory::dims output_dims{1, 32, 112, 112}; mkldnn::memory::desc out_md{output_dims, get_mkldnn_type(dtype), - mkldnn::memory::format::any}; + mkldnn::memory::format_tag::any}; mkldnn::memory::dims strides{1, 1}; mkldnn::memory::dims padding{1, 1}; mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training, mkldnn::algorithm::convolution_direct, data_md, weight_md, out_md, strides, - padding, padding, mkldnn::padding_kind::zero); + padding, padding); mkldnn::convolution_forward::primitive_desc pd(desc, CpuEngine::Get()->get_engine()); - while (pd.dst_primitive_desc().get_size() != GetMemDescSize(out_md) || - pd.src_primitive_desc().get_size() != GetMemDescSize(data_md) || - pd.weights_primitive_desc().get_size() != GetMemDescSize(weight_md)) { + while (pd.dst_desc().get_size() != GetMemDescSize(out_md) || + pd.src_desc().get_size() != GetMemDescSize(data_md) || + pd.weights_desc().get_size() != GetMemDescSize(weight_md)) { CHECK(pd.next_impl()) << "No implementation"; } - std::vector ret(1); - ret[0] = static_cast(pd.weights_primitive_desc().desc().data.format); - printf("format: %d\n", ret[0]); + std::vector ret(1); + ret[0] = static_cast(GetDefaultFormat(pd.weights_desc())); + printf("format: %d\n", static_cast(ret[0])); return ret; } else { - return std::vector(); + return std::vector(); } } inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = false) { int dtype = mshadow::DataType::kFlag; mxnet::ShapeVector shapes; - std::vector pds; + std::vector mds; { // 1D mxnet::TShape s(1, -1); s[0] = 279936; shapes.push_back(s); - pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::x)); + mds.push_back(GetMemDesc(s, dtype, mkldnn::memory::format_tag::x)); s[0] = 34848; shapes.push_back(s); - pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::x)); + mds.push_back(GetMemDesc(s, dtype, mkldnn::memory::format_tag::x)); } { // 2D @@ -179,27 +179,27 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals s[0] = 96; s[1] = 2916; shapes.push_back(s); - pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::nc)); + mds.push_back(GetMemDesc(s, dtype, mkldnn::memory::format_tag::nc)); s[0] = 96; s[1] = 363; shapes.push_back(s); - pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::nc)); + mds.push_back(GetMemDesc(s, dtype, mkldnn::memory::format_tag::nc)); } { // 4D mxnet::TShape s1(4, -1); s1[0] = 10; s1[1] = 96; s1[2] = 54; s1[3] = 54; shapes.push_back(s1); - pds.push_back(GetMemPD(s1, dtype, mkldnn::memory::format::nchw)); + mds.push_back(GetMemDesc(s1, dtype, mkldnn::memory::format_tag::nchw)); mxnet::TShape s2(4, -1); s2[0] = 96; s2[1] = 3; s2[2] = 11; s2[3] = 11; shapes.push_back(s2); - pds.push_back(GetMemPD(s2, dtype, mkldnn::memory::format::oihw)); + mds.push_back(GetMemDesc(s2, dtype, mkldnn::memory::format_tag::oihw)); - std::vector formats = GetMKLDNNFormat(4, dtype); + std::vector formats = GetMKLDNNFormat(4, dtype); if (!spatial_data_format) { - pds.push_back(GetMemPD(s1, dtype, formats[0])); + mds.push_back(GetMemDesc(s1, dtype, formats[0])); } } { @@ -207,17 +207,17 @@ inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = fals mxnet::TShape s(5, -1); s[0] = 96; s[1] = 1; s[2] = 3; s[3] = 11; s[4] = 11; shapes.push_back(s); - pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::goihw)); + mds.push_back(GetMemDesc(s, dtype, mkldnn::memory::format_tag::goihw)); - std::vector formats = GetMKLDNNFormat(5, dtype); + std::vector formats = GetMKLDNNFormat(5, dtype); if (!spatial_data_format) { - pds.push_back(GetMemPD(s, dtype, formats[0])); + mds.push_back(GetMemDesc(s, dtype, formats[0])); } } TestArrayShapes ret; ret.shapes = shapes; - ret.pds = pds; + ret.mds = mds; return ret; } @@ -266,7 +266,7 @@ inline NDArray CreateKernelNDArray(mxnet::TShape kernel, int num_filters, mxnet: target_shape[3] = kernel[1]; int dtype = mshadow::DataType::kFlag; NDArray arr(target_shape, Context()); - auto pd = GetMemPD(target_shape, dtype, mkldnn::memory::format::nchw); + auto pd = GetMemDesc(target_shape, dtype, mkldnn::memory::format_tag::nchw); InitMKLDNNArray(&arr, pd); return arr; } @@ -274,7 +274,7 @@ inline NDArray CreateKernelNDArray(mxnet::TShape kernel, int num_filters, mxnet: inline NDArray CreateBiasNDArray(mxnet::TShape target_shape) { int dtype = mshadow::DataType::kFlag; NDArray arr(target_shape, Context()); - auto pd = GetMemPD(target_shape, dtype, mkldnn::memory::format::x); + auto pd = GetMemDesc(target_shape, dtype, mkldnn::memory::format_tag::x); InitMKLDNNArray(&arr, pd); return arr; } @@ -333,10 +333,10 @@ inline std::vector GetTestInputArrays( std::vector scale = {1}, bool spatial_data_format = false) { TestArrayShapes tas = GetTestArrayShapes(spatial_data_format); std::vector shapes = tas.shapes; - std::vector pds = tas.pds; + std::vector mds = tas.mds; std::vector in_arrs; - std::string desc; + std::string desc_str; int slice_amount = scale[0]; for (auto shape : shapes) { @@ -362,60 +362,60 @@ inline std::vector GetTestInputArrays( } - for (auto pd : pds) { + for (auto md : mds) { for (size_t dim = 0; dim < scale.size(); ++dim) { // preserve if matching layout else just expand on 0 dim - if (shape.ndim() == pd.desc().data.ndims) - pd = GetExpandedMemPD(pd, scale[dim], dim); + if (shape.ndim() == md.data.ndims) + md = GetExpandedMemDesc(md, scale[dim], dim); else - pd = GetExpandedMemPD(pd, scale[dim]); + md = GetExpandedMemDesc(md, scale[dim]); } - if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t)) + if (shape.Size() != md.get_size() / sizeof(mshadow::default_real_t)) continue; // Type 2, 3. arr = NDArray(shape, Context()); - if (shape.ndim() == pd.desc().data.ndims && IsSameShape(pd, shape) + if (shape.ndim() == md.data.ndims && IsSameShape(md, shape) && types & ArrayTypes::MKLDNN) { - desc = "MKLDNN NDArray"; - InitMKLDNNArray(&arr, pd, rand); - in_arrs.emplace_back(arr, desc); - } else if (shape.ndim() == pd.desc().data.ndims && !IsSameShape(pd, shape) + desc_str = "MKLDNN NDArray"; + InitMKLDNNArray(&arr, md, rand); + in_arrs.emplace_back(arr, desc_str); + } else if (shape.ndim() == md.data.ndims && !IsSameShape(md, shape) && types & ArrayTypes::MKLDNNDiffShape) { - desc = "MKLDNN NDArray with different shape"; - InitMKLDNNArray(&arr, pd, rand); - in_arrs.emplace_back(arr, desc); - } else if (shape.ndim() != pd.desc().data.ndims && types & ArrayTypes::MKLDNNDiffDim) { + desc_str = "MKLDNN NDArray with different shape"; + InitMKLDNNArray(&arr, md, rand); + in_arrs.emplace_back(arr, desc_str); + } else if (shape.ndim() != md.data.ndims && types & ArrayTypes::MKLDNNDiffDim) { std::stringstream ss; ss << "MKLDNN NDArray with different dim " << - shape.ndim() << "/" << pd.desc().data.ndims; - desc = ss.str(); - InitMKLDNNArray(&arr, pd, rand); - in_arrs.emplace_back(arr, desc); + shape.ndim() << "/" << md.data.ndims; + desc_str = ss.str(); + InitMKLDNNArray(&arr, md, rand); + in_arrs.emplace_back(arr, desc_str); } // Type 5, 6. arr = NDArray(shape, Context()); - if (shape.ndim() == pd.desc().data.ndims && IsSameShape(pd, shape) + if (shape.ndim() == md.data.ndims && IsSameShape(md, shape) && types & ArrayTypes::MKLDNNReshaped) { - desc = "Reshaped MKLDNN NDArray"; - InitMKLDNNArray(&arr, pd, rand); - in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc); - } else if (shape.ndim() == pd.desc().data.ndims && !IsSameShape(pd, shape) + desc_str = "Reshaped MKLDNN NDArray"; + InitMKLDNNArray(&arr, md, rand); + in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str); + } else if (shape.ndim() == md.data.ndims && !IsSameShape(md, shape) && types & ArrayTypes::MKLDNNReshapedDiffShape) { - desc = "Reshaped MKLDNN NDArray with different shape"; - InitMKLDNNArray(&arr, pd, rand); - in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc); - } else if (shape.ndim() != pd.desc().data.ndims + desc_str = "Reshaped MKLDNN NDArray with different shape"; + InitMKLDNNArray(&arr, md, rand); + in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str); + } else if (shape.ndim() != md.data.ndims && types & ArrayTypes::MKLDNNReshapedDiffDim) { std::stringstream ss; ss << "MKLDNN NDArray with different dim " << - shape.ndim() << "/" << pd.desc().data.ndims; - desc = ss.str(); - InitMKLDNNArray(&arr, pd, rand); - in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc); + shape.ndim() << "/" << md.data.ndims; + desc_str = ss.str(); + InitMKLDNNArray(&arr, md, rand); + in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str); } } } @@ -444,7 +444,7 @@ inline std::vector GetTestInputArrays( */ inline std::vector GetTestOutputArrays( const mxnet::TShape &shp, - const std::vector &pds, + const std::vector &mds, std::vectorscale = {1}, bool rand = true, int types = ArrayTypes::All) { mxnet::TShape shape = shp; @@ -452,7 +452,7 @@ inline std::vector GetTestOutputArrays( shape[dim] = static_cast(shape[dim] * scale[dim]); std::vector in_arrs; - std::string desc; + std::string desc_str; // Type 1. NDArray arr(shape, Context()); @@ -500,30 +500,30 @@ inline std::vector GetTestOutputArrays( in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1), "Reused+Reshaped NDArray"); } - for (auto pd : pds) { - if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t)) + for (auto md : mds) { + if (shape.Size() != md.get_size() / sizeof(mshadow::default_real_t)) continue; - if (scale.size() > pd.desc().data.ndims) + if (scale.size() > md.data.ndims) continue; for (int dim = 0; dim < scale.size(); dim++) - pd = GetExpandedMemPD(pd, scale[dim]); + md = GetExpandedMemDesc(md, scale[dim]); // Type 2, 3. arr = NDArray(shape, Context()); - desc = "MKLDNN NDArray"; - if (shape.ndim() != pd.desc().data.ndims) { + desc_str = "MKLDNN NDArray"; + if (shape.ndim() != md.data.ndims) { std::stringstream ss; ss << "MKLDNN NDArray with different memory layout " - << shape.ndim() << "/" << pd.desc().data.ndims; - desc = ss.str(); + << shape.ndim() << "/" << md.data.ndims; + desc_str = ss.str(); } - if ((types & ArrayTypes::MKLDNN && shape.ndim() == pd.desc().data.ndims) || - (types & ArrayTypes::MKLDNNDiffDim && shape.ndim() != pd.desc().data.ndims)) { - in_arrs.emplace_back(arr, desc); - InitMKLDNNArray(&in_arrs.back().arr, pd, rand); + if ((types & ArrayTypes::MKLDNN && shape.ndim() == md.data.ndims) || + (types & ArrayTypes::MKLDNNDiffDim && shape.ndim() != md.data.ndims)) { + in_arrs.emplace_back(arr, desc_str); + InitMKLDNNArray(&in_arrs.back().arr, md, rand); } // Type 8, 9. @@ -532,18 +532,18 @@ inline std::vector GetTestOutputArrays( s[0] = shape.Size(); NDArray arr = NDArray(s, Context()); arr = arr.AsArray(shape, arr.dtype()); - InitMKLDNNArray(&arr, pd, rand); - desc = "Reused MKLDNN NDArray"; - if (shape.ndim() != pd.desc().data.ndims) { + InitMKLDNNArray(&arr, md, rand); + desc_str = "Reused MKLDNN NDArray"; + if (shape.ndim() != md.data.ndims) { std::stringstream ss; ss << "Reused MKLDNN NDArray with different memory layout " - << shape.ndim() << "/" << pd.desc().data.ndims; - desc = ss.str(); + << shape.ndim() << "/" << md.data.ndims; + desc_str = ss.str(); } - if ((types & ArrayTypes::MKLDNNReused && shape.ndim() == pd.desc().data.ndims) || - (types & ArrayTypes::MKLDNNReusedDiffDim && shape.ndim() != pd.desc().data.ndims)) { - in_arrs.emplace_back(arr, desc); + if ((types & ArrayTypes::MKLDNNReused && shape.ndim() == md.data.ndims) || + (types & ArrayTypes::MKLDNNReusedDiffDim && shape.ndim() != md.data.ndims)) { + in_arrs.emplace_back(arr, desc_str); } } return in_arrs; @@ -581,9 +581,9 @@ using VerifyFunc = std::function &in_arrs, const std::vector &out_arrs)>; inline void VerifyAddRequest(const std::vector &in_arrs, - const std::vector &original_outputs, - const std::vector &new_outputs, - VerifyFunc verify_fn) { + const std::vector &original_outputs, + const std::vector &new_outputs, + VerifyFunc verify_fn) { CHECK(original_outputs.size() == new_outputs.size()); std::vector tmp_outputs; NDArray tmp; @@ -596,7 +596,7 @@ inline void VerifyAddRequest(const std::vector &in_arrs, } inline void VerifyCopyResult(const std::vector &in_arrs, - const std::vector &out_arrs) { + const std::vector &out_arrs) { NDArray tmp1 = in_arrs[0]->Reorder2Default(); NDArray tmp2 = out_arrs[0]->Reorder2Default(); EXPECT_EQ(tmp1.shape().Size(), tmp2.shape().Size()); @@ -607,7 +607,7 @@ inline void VerifyCopyResult(const std::vector &in_arrs, } inline void VerifySumResult(const std::vector &in_arrs, - const std::vector &out_arrs) { + const std::vector &out_arrs) { NDArray in1 = in_arrs[0]->Reorder2Default(); NDArray in2 = in_arrs[1]->Reorder2Default(); NDArray out = out_arrs[0]->Reorder2Default(); @@ -621,5 +621,5 @@ inline void VerifySumResult(const std::vector &in_arrs, ASSERT_EQ(d1[i] + d2[i], o[i]); } -#endif // MXNET_USE_MKLDNN +#endif // MXNET_USE_MKLDNN == 1 #endif // TEST_MKLDNN_H_ diff --git a/tests/cpp/include/test_op.h b/tests/cpp/include/test_op.h index 67d98c4457e1..172c162e6f15 100644 --- a/tests/cpp/include/test_op.h +++ b/tests/cpp/include/test_op.h @@ -153,9 +153,9 @@ struct OpInfo { /*! \brief The operator data */ std::shared_ptr< OperatorExecutor > executor_; /*! \brief The operator prop class */ - std::shared_ptr prop_; + std::shared_ptr prop_; /*! \brief The input type(s) */ - std::vector in_type_; + std::vector in_type_; }; /*! \brief Pair of op info objects, generally for validating ops against each other */ diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc index 29afb16bdc5b..d26894c21ea7 100644 --- a/tests/cpp/kvstore/gpu_topology_test.cc +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -23,6 +23,8 @@ * \brief gpu topology tests */ +#if MXNET_USE_CUDA + #include #include #include @@ -670,3 +672,5 @@ TEST(GpuTopology, TestKernighanLin2) { << " not equal neither: " << 0 << " nor: " << P.size() << "."; } + +#endif // MXNET_USE_CUDA diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index ed0e70b831f1..74c2b546f161 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -710,12 +710,12 @@ static constexpr size_t CYCLE_COUNT = 3; template static test::op::OpInfoPair testForwardAndBackward( - const bool isGPU1, - const bool isGPU2, - const mxnet::TShape &inputShape, - const test::op::kwargs_t& kwargs, - const size_t count = 1, - const size_t cycleCount = CYCLE_COUNT) { + const bool isGPU1, + const bool isGPU2, + const mxnet::TShape &inputShape, + const test::op::kwargs_t& kwargs, + const size_t count = 1, + const size_t cycleCount = CYCLE_COUNT) { test::op::OpInfo info_1 = TestBatchNormOperatorForward(isGPU1, inputShape, kwargs, count); @@ -1014,14 +1014,14 @@ TEST(BATCH_NORM, TestTiming_2D) { } MSHADOW_REAL_TYPE_SWITCH_EX( mshadow::kFloat32, DType, AccReal, { -#if MXNET_USE_MKLDNN +#if MXNET_USE_MKLDNN == 1 // MKL timingTest>( "MKL BatchNormProp 2D", false, false, blank_kwargs_nocudnn, 2, THISCOUNT); -#endif +#endif // MXNET_USE_MKLDNN == 1 // CPU test::ScopeSet disableMKL(&mxnet::op::batchnorm::disable_mkl, true); timingTest>( diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc index 961785dcfc87..8ae1db6c7712 100644 --- a/tests/cpp/operator/mkldnn_operator_test.cc +++ b/tests/cpp/operator/mkldnn_operator_test.cc @@ -458,7 +458,7 @@ void VerifyConcatResult(const std::vector &in_arrs, } void VerifyConcatBackwardsResult(const std::vector &in_arrs, - const std::vector &out_arrs) { + const std::vector &out_arrs) { // in_arrs is larger array, out_arr is ammler int num_inputs = out_arrs.size(); int input_size = out_arrs[0]->shape().Size(); @@ -491,7 +491,7 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) { std::vector dispatches = attrs.dispatches; TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; if (attrs.requests.find(OpReqType::kWriteTo) != attrs.requests.end()) { std::vector in_arrs = GetTestInputArrays(); @@ -499,7 +499,7 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) { for (auto &dispatch : dispatches) { std::vector> out_arrs(attrs.num_outputs); for (int i = 0; i < attrs.num_outputs; i++) - out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds); + out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), mds); for (int i = 0; i < attrs.num_inputs; i++) inputs[i] = &in_arr.arr; for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { @@ -549,7 +549,7 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) { for (auto &in_arr : in_arrs) { for (auto &dispatch : dispatches) { for (int i = 0; i < attrs.num_outputs; i++) - out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds); + out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), mds); for (size_t i = 0; i < attrs.num_inputs; i++) inputs[i] = &in_arr.arr; for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { @@ -573,14 +573,14 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) { } void TestConcatOp(const OpAttrs &attrs, VerifyFunc verify_fn, - bool backwards = false) { + bool backwards = false) { std::vector inputs(attrs.num_inputs); std::vector outputs(attrs.num_outputs); std::vector req(attrs.num_outputs); std::vector dispatches = attrs.dispatches; TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; std::vector in_arrs = GetTestInputArrays(); @@ -611,7 +611,7 @@ void TestConcatOp(const OpAttrs &attrs, VerifyFunc verify_fn, scale_vector[i] = 1; scale_vector[dim] = scale; for (int i = 0; i < attrs.num_outputs; i++) - out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds, scale_vector); + out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), mds, scale_vector); for (int i = 0; i < attrs.num_inputs; i++) inputs[i] = &in_arr.arr; @@ -678,7 +678,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { std::vector req(forward_attrs.num_outputs); TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; std::vector in_arrs = GetTestInputArrays(forward_attrs.input_types, true); std::vector> out_arrs(forward_attrs.num_outputs); @@ -695,9 +695,9 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { for (int i = 0; i < forward_attrs.num_outputs; i++) { out_arrs[i] = - GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, forward_attrs.output_types); + GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, forward_attrs.output_types); ex_out_arrs[i] = - GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, forward_attrs.output_types); + GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, forward_attrs.output_types); } for (int i = 0; i < forward_attrs.num_inputs; i++) @@ -806,7 +806,7 @@ void TestOpExBNBackward(const OpAttrs &forward_attrs, Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs, backwards_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr()); Engine::Get()->WaitForAll(); - AssertEqual(backwards_outputs, backwards_ex_outputs); + AssertEqual(backwards_outputs, backwards_ex_outputs, 1e-4, 1e-2); } } @@ -821,7 +821,7 @@ void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { std::vector req(forward_attrs.num_outputs); TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; std::vector in_arrs = GetTestInputArrays(forward_attrs.input_types, false); std::vector> out_arrs(forward_attrs.num_outputs); @@ -837,9 +837,9 @@ void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { continue; for (int i = 0; i < forward_attrs.num_outputs; i++) { out_arrs[i] = - GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types); + GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, true, forward_attrs.output_types); ex_out_arrs[i] = - GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types); + GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, true, forward_attrs.output_types); } for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { inputs_buffer.clear(); @@ -867,11 +867,11 @@ void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { Context(), forward_attrs.attrs, inputs2, ex_outputs, req, DispatchMode::kFComputeEx, mxnet::OpStatePtr()); Engine::Get()->WaitForAll(); - AssertEqual(outputs, ex_outputs); + AssertEqual(outputs, ex_outputs, 1e-04, 1e-02); if (!backwards_attrs.requests.empty()) { TestOpExBNBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo, - inputs, outputs, in_arr, &out_arrs[0][output_i]); + inputs, outputs, in_arr, &out_arrs[0][output_i]); } } } @@ -900,7 +900,7 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards std::vector back_req(backwards_attrs.num_outputs); TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; std::vector in_arrs = GetTestInputArrays(forward_attrs.input_types, true); std::vector> out_arrs(forward_attrs.num_outputs); @@ -937,9 +937,9 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards for (int i = 0; i < forward_attrs.num_outputs; i++) { out_arrs[i] = - GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types); + GetTestOutputArrays(out_shape, mds, {1}, forward_attrs.output_types); ex_out_arrs[i] = - GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types); + GetTestOutputArrays(out_shape, mds, {1}, forward_attrs.output_types); } for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { @@ -1014,7 +1014,7 @@ void TestConvOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs, std::vector dispatches = forward_attrs.dispatches; TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; P param; param.Init(forward_attrs.attrs.dict); @@ -1050,9 +1050,9 @@ void TestConvOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs, scale_vector[3] = scale; for (size_t i = 0; i < forward_attrs.num_outputs; ++i) { - out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds, + out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), mds, scale_vector, true, forward_attrs.output_types); - ex_out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds, + ex_out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), mds, scale_vector, true, forward_attrs.output_types); } NDArray ndkernel = CreateKernelNDArray(kernel, num_filter, in_arr.arr.shape(), is_deconv); @@ -1140,7 +1140,7 @@ void TestPoolingOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) std::vector dispatches = forward_attrs.dispatches; TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; mxnet::op::PoolingParam param; param.Init(forward_attrs.attrs.dict); @@ -1160,7 +1160,7 @@ void TestPoolingOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) if (input_shape.ndim() != kernel.ndim() + 2) continue; // cannot pool if ndarray and mkldnn memory have different ndim - if (in_arr.arr.IsView() || in_arr.arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims + if (in_arr.arr.IsView() || in_arr.arr.GetMKLDNNData()->get_desc().data.ndims != in_arr.arr.shape().ndim()) continue; std::vector scale_vector(in_arr.arr.shape().ndim()); @@ -1173,8 +1173,8 @@ void TestPoolingOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) static_cast(input_shape[i]); } for (int i = 0; i < forward_attrs.num_outputs; i++) { - out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds, scale_vector); - ex_out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds, scale_vector); + out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), mds, scale_vector); + ex_out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), mds, scale_vector); } for (int i = 0; i < forward_attrs.num_inputs; i++) @@ -1353,4 +1353,4 @@ TEST(IMPERATIVE, BNOp) { TestOpExBN(forward_attrs, backwards_attrs); } -#endif +#endif // MXNET_USE_MKLDNN == 1 diff --git a/tests/cpp/operator/mkldnn_test.cc b/tests/cpp/operator/mkldnn_test.cc index ab624e3a3c44..bcdb38ac4aa8 100644 --- a/tests/cpp/operator/mkldnn_test.cc +++ b/tests/cpp/operator/mkldnn_test.cc @@ -88,10 +88,10 @@ TEST(MKLDNN_UTIL_FUNC, AlignMem) { } static void VerifyDefMem(const mkldnn::memory &mem) { - mkldnn::memory::primitive_desc pd = mem.get_primitive_desc(); + mkldnn::memory::desc desc = mem.get_desc(); mshadow::default_real_t *data = static_cast(mem.get_data_handle()); - size_t size = pd.get_size() / sizeof(mshadow::default_real_t); + size_t size = desc.get_size() / sizeof(mshadow::default_real_t); size_t num_same = 0; for (int i = 0; i < size; i++) num_same += data[i] == static_cast(i % 100 - 50); @@ -100,29 +100,30 @@ static void VerifyDefMem(const mkldnn::memory &mem) { TEST(MKLDNN_UTIL_FUNC, MemFormat) { // Check whether the number of format is correct. - CHECK_EQ(mkldnn_format_last, 158); - CHECK_EQ(mkldnn_nchw, 7); - CHECK_EQ(mkldnn_oihw, 17); + CHECK_EQ(mkldnn_format_tag_last, 131); + CHECK_EQ(mkldnn_nchw, 5); + CHECK_EQ(mkldnn_oihw, 5); } static void VerifyMem(const mkldnn::memory &mem) { - mkldnn::memory::primitive_desc pd = mem.get_primitive_desc(); + mkldnn::memory::desc desc = mem.get_desc(); + mkldnn::memory::dims dims(desc.data.ndims); + for (size_t i = 0; i < dims.size(); i++) + dims[i] = desc.data.dims[i]; + mkldnn::memory::desc new_desc{dims, + static_cast(desc.data.data_type), + static_cast(GetDefaultFormat(desc))}; - if (pd.desc().data.format == GetDefaultFormat(pd.desc())) { + if (desc == new_desc) { VerifyDefMem(mem); } else { - mkldnn::memory::dims dims(pd.desc().data.ndims); - for (size_t i = 0; i < dims.size(); i++) - dims[i] = pd.desc().data.dims[i]; - mkldnn::memory::desc desc{dims, - static_cast(pd.desc().data.data_type), - static_cast(GetDefaultFormat(pd.desc()))}; - mkldnn::memory::primitive_desc new_pd(desc, CpuEngine::Get()->get_engine()); - mkldnn::memory new_mem(new_pd); - - std::vector net; - net.push_back(mkldnn::reorder(mem, new_mem)); - mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + mkldnn::memory* src_mem = const_cast(&mem); + mkldnn::memory new_mem(new_desc, CpuEngine::Get()->get_engine()); + + mkldnn::stream s(CpuEngine::Get()->get_engine()); + mkldnn::reorder(*src_mem, new_mem) + .execute(s, *src_mem, new_mem); + VerifyDefMem(new_mem); } } @@ -130,23 +131,23 @@ static void VerifyMem(const mkldnn::memory &mem) { TEST(MKLDNN_NDArray, GetDataReorder) { TestArrayShapes tas = GetTestArrayShapes(); mxnet::ShapeVector shapes = tas.shapes; - std::vector pds = tas.pds; + std::vector mds = tas.mds; // Reorder from the default to any other layout. for (auto s : shapes) { NDArray arr(s, Context()); InitDefaultArray(&arr); - for (auto pd : pds) { - if (s.Size() == pd.get_size() / sizeof(mshadow::default_real_t)) { - const mkldnn::memory *mem = arr.GetMKLDNNDataReorder(pd); + for (auto md : mds) { + if (s.Size() == md.get_size() / sizeof(mshadow::default_real_t)) { + const mkldnn::memory *mem = arr.GetMKLDNNDataReorder(md); printf("reorder from ("); for (size_t i = 0; i < s.ndim(); i++) printf("%ld, ", s[i]); printf(") to ("); - for (int i = 0; i < pd.desc().data.ndims; i++) - printf("%d, ", pd.desc().data.dims[i]); - printf("), format: %d\n", pd.desc().data.format); + for (int i = 0; i < md.data.ndims; i++) + printf("%ld, ", md.data.dims[i]); + printf("), format: %d\n", static_cast(GetDefaultFormat(md))); MKLDNNStream::Get()->Submit(false); VerifyMem(*mem); MKLDNNStream::Get()->Cleanup(); @@ -156,8 +157,8 @@ TEST(MKLDNN_NDArray, GetDataReorder) { // Reorder from a special layout to another layout. for (auto s : shapes) { - for (auto from_pd : pds) { - if (from_pd.get_size() / sizeof(mshadow::default_real_t) == s.Size()) { + for (auto md : mds) { + if (md.get_size() / sizeof(mshadow::default_real_t) == s.Size()) { NDArray arr(s, Context()); // There is possibility that the dimensions of an NDArray doesn't match // with the MKLDNN memory inside. @@ -165,21 +166,20 @@ TEST(MKLDNN_NDArray, GetDataReorder) { for (size_t i = 0; i < s.ndim(); i++) printf("%ld, ", s[i]); printf(") with MKLDNN memory ("); - for (int i = 0; i < from_pd.desc().data.ndims; i++) - printf("%d, ", from_pd.desc().data.dims[i]); - printf("), format: %d\n", from_pd.desc().data.format); - InitMKLDNNArray(&arr, from_pd); - for (auto to_pd : pds) { - if (to_pd.get_size() / sizeof(mshadow::default_real_t) == s.Size()) { - const mkldnn::memory *mem = arr.GetMKLDNNDataReorder(to_pd); + for (int i = 0; i < md.data.ndims; i++) + printf("%ld, ", md.data.dims[i]); + printf("), format: %d\n", static_cast(GetDefaultFormat(md))); + InitMKLDNNArray(&arr, md); + for (auto to_md : mds) { + if (to_md.get_size() / sizeof(mshadow::default_real_t) == s.Size()) { + const mkldnn::memory *mem = arr.GetMKLDNNDataReorder(to_md); printf("reorder from ("); for (size_t i = 0; i < s.ndim(); i++) printf("%ld, ", s[i]); - printf("), format: %d to (", - arr.GetMKLDNNData()->get_primitive_desc().desc().data.format); - for (int i = 0; i < to_pd.desc().data.ndims; i++) - printf("%d, ", to_pd.desc().data.dims[i]); - printf("), format: %d\n", to_pd.desc().data.format); + printf("), format: %d to (", static_cast(GetDefaultFormat(to_md))); + for (int i = 0; i < to_md.data.ndims; i++) + printf("%ld, ", to_md.data.dims[i]); + printf("), format: %d\n", static_cast(GetDefaultFormat(to_md))); MKLDNNStream::Get()->Submit(false); VerifyMem(*mem); MKLDNNStream::Get()->Cleanup(); @@ -194,7 +194,7 @@ TEST(MKLDNN_BASE, MKLDNNSum) { std::vector in_arrs = GetTestInputArrays(); std::vector in_arrs2 = GetTestInputArrays(ArrayTypes::All, true); TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; for (int i = 0; i < in_arrs.size(); i++) { auto in_arr = in_arrs[i]; @@ -204,7 +204,7 @@ TEST(MKLDNN_BASE, MKLDNNSum) { if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) { continue; } - std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); + std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds); for (auto &out_arr : out_arrs) { auto in_mem1 = in_arr.arr.GetMKLDNNData(); auto in_mem2 = in_arr2.arr.GetMKLDNNData(); @@ -232,7 +232,7 @@ TEST(MKLDNN_BASE, MKLDNNSum) { NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy"); orig_arr.arr.WaitToRead(); PrintVerifyMsg(orig_arr, in_arr); - InitMKLDNNArray(&orig_arr.arr, input_mem->get_primitive_desc()); + InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc()); orig_arr.arr.CopyFrom(*input_mem); op::MKLDNNSum(*input_mem, *input_mem2, *input_mem); MKLDNNStream::Get()->Submit(); @@ -244,7 +244,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { std::vector in_arrs = GetTestInputArrays(); std::vector in_arrs2 = GetTestInputArrays(ArrayTypes::All, true); TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; MKLDNNStream *stream = MKLDNNStream::Get(); // kWriteTo @@ -256,7 +256,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) { continue; } - std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); + std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds); for (auto &out_arr : out_arrs) { auto in_mem = in_arr.arr.GetMKLDNNData(); auto in_mem2 = in_arr2.arr.GetMKLDNNData(); @@ -264,7 +264,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { orig_output.WaitToRead(); PrintVerifyMsg(in_arr, out_arr); auto out_mem = out_arr.arr.GetMKLDNNData(); - auto output_mem_t = CreateMKLDNNMem(out_arr.arr, out_mem->get_primitive_desc(), kWriteTo); + auto output_mem_t = CreateMKLDNNMem(out_arr.arr, out_mem->get_desc(), kWriteTo); op::MKLDNNSum(*in_mem, *in_mem2, *output_mem_t.second); CommitOutput(out_arr.arr, output_mem_t); stream->Submit(); @@ -286,10 +286,10 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy"); orig_arr.arr.WaitToRead(); PrintVerifyMsg(orig_arr, in_arr); - InitMKLDNNArray(&orig_arr.arr, input_mem->get_primitive_desc()); + InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc()); orig_arr.arr.CopyFrom(*input_mem); auto output_mem_t = CreateMKLDNNMem(in_arr.arr, - input_mem->get_primitive_desc(), kWriteInplace, &in_arr.arr); + input_mem->get_desc(), kWriteInplace, &in_arr.arr); op::MKLDNNSum(*input_mem, *input_mem2, *output_mem_t.second); CommitOutput(in_arr.arr, output_mem_t); stream->Submit(); @@ -305,7 +305,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) { continue; } - std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); + std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds); for (auto &out_arr : out_arrs) { auto in_mem = in_arr.arr.GetMKLDNNData(); auto in_mem2 = in_arr2.arr.GetMKLDNNData(); @@ -313,7 +313,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { orig_output.WaitToRead(); PrintVerifyMsg(in_arr, out_arr); auto out_mem = out_arr.arr.GetMKLDNNData(); - auto output_mem_t = CreateMKLDNNMem(out_arr.arr, out_mem->get_primitive_desc(), kAddTo); + auto output_mem_t = CreateMKLDNNMem(out_arr.arr, out_mem->get_desc(), kAddTo); op::MKLDNNSum(*in_mem, *in_mem2, *output_mem_t.second); CommitOutput(out_arr.arr, output_mem_t); stream->Submit(); @@ -336,9 +336,9 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy"); orig_arr.arr.WaitToRead(); PrintVerifyMsg(orig_arr, in_arr); - InitMKLDNNArray(&orig_arr.arr, input_mem->get_primitive_desc()); + InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc()); orig_arr.arr.CopyFrom(*input_mem); - auto output_mem_t = CreateMKLDNNMem(in_arr.arr, input_mem->get_primitive_desc(), kNullOp); + auto output_mem_t = CreateMKLDNNMem(in_arr.arr, input_mem->get_desc(), kNullOp); op::MKLDNNSum(*input_mem, *input_mem2, *output_mem_t.second); CommitOutput(in_arr.arr, output_mem_t); stream->Submit(); @@ -373,8 +373,8 @@ TEST(MKLDNN_NDArray, GetTestInputArraysConcat) { TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) { auto shapes_pds = GetTestArrayShapes(); - std::vector shapes; shapes = shapes_pds.shapes; - std::vector pds = shapes_pds.pds; + std::vector shapes = shapes_pds.shapes; + std::vector mds = shapes_pds.mds; for (auto &shape : shapes) { for (int dim = 0; dim < 5; dim++) { for (int num_inputs = 2; num_inputs < 5; num_inputs++) { @@ -386,7 +386,7 @@ TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) { for (int i = 0; i < shape.ndim(); i++) scale_vector[i] = 1; scale_vector[dim] = num_inputs; - auto output_arrs = GetTestOutputArrays(shape, pds, scale_vector); + auto output_arrs = GetTestOutputArrays(shape, mds, scale_vector); for (auto &out_arr : output_arrs) { auto out_shape = out_arr.arr.shape(); EXPECT_EQ(shape.Size() * num_inputs, out_arr.arr.shape().Size()); @@ -399,13 +399,13 @@ TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) { TEST(MKLDNN_NDArray, CopyFrom) { TestArrayShapes tas = GetTestArrayShapes(); - std::vector pds = tas.pds; + std::vector mds = tas.mds; std::vector in_arrs = GetTestInputArrays(); for (auto &in_arr : in_arrs) { if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) continue; - std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); + std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds); for (auto &out_arr : out_arrs) { const mkldnn::memory *mem = in_arr.arr.GetMKLDNNData(); out_arr.arr.CopyFrom(*mem); @@ -417,4 +417,4 @@ TEST(MKLDNN_NDArray, CopyFrom) { } } -#endif +#endif // MXNET_USE_MKLDNN == 1 diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index f88c0a888320..e43daf12c464 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -568,5 +568,73 @@ def test_weight_async_reorder(): for output in mod.get_outputs(): output.wait_to_read() +@with_seed() +def test_concat(): + def ref_concat(a, b, axis): + return np.concatenate((a, b), axis=axis) + + a_sym = mx.sym.Variable("a") + b_sym = mx.sym.Variable("b") + dshape = rand_shape_nd(4) + a_shape = tuple(dshape) + b_shape = tuple(dshape) + + for axis in range(0, 4): + z = mx.sym.concat(a_sym, b_sym, dim=axis) + a = np.random.uniform(-1, 1, a_shape) + b = np.random.uniform(-1, 1, b_shape) + exe = z.simple_bind(ctx=mx.cpu(), a=a_shape, b=b_shape) + out = exe.forward(is_train=False, a=a, b=b) + ref_out = ref_concat(a, b, axis=axis) + out = out[0].asnumpy() + assert_almost_equal(out, ref_out) + + def check_concat_training(stype): + data_shape = rand_shape_nd(4) + for density in [1.0, 0.5, 0.0]: + a_sym = mx.sym.Variable('a') + b_sym = mx.sym.Variable('b') + sym = mx.sym.concat(a_sym, b_sym, dim=1) + a = rand_ndarray(shape=data_shape, stype=stype, density=density) + b = rand_ndarray(shape=data_shape, stype=stype, density=density) + in_location = [a, b] + check_numeric_gradient(sym, in_location, numeric_eps=1e-3, rtol=1e-3, atol=5e-3) + stypes = ['row_sparse', 'default'] + for stype in stypes: + check_concat_training(stype) + +@with_seed() +def test_elemwise_add(): + def ref_add(a, b): + return np.add(a, b) + + a_sym = mx.sym.Variable("a") + b_sym = mx.sym.Variable("b") + dshape = rand_shape_nd(4) + a_shape = tuple(dshape) + b_shape = tuple(dshape) + z = mx.sym.elemwise_add(a_sym, b_sym) + a = np.random.uniform(-1, 1, a_shape) + b = np.random.uniform(-1, 1, b_shape) + exe = z.simple_bind(ctx=mx.cpu(), a=a_shape, b=b_shape) + out = exe.forward(is_train=False, a=a, b=b) + ref_out = ref_add(a, b) + out = out[0].asnumpy() + assert_almost_equal(out, ref_out, rtol=1e-6, atol=1e-6) + + def check_elemwise_add_training(stype): + data_shape = rand_shape_nd(4) + for density in [1.0, 0.5, 0.0]: + a_sym = mx.sym.Variable('a') + b_sym = mx.sym.Variable('b') + sym = mx.sym.elemwise_add(a_sym, b_sym) + a = rand_ndarray(shape=data_shape, stype=stype, density=density) + b = rand_ndarray(shape=data_shape, stype=stype, density=density) + in_location = [a, b] + check_numeric_gradient(sym, in_location, numeric_eps=1e-3, rtol=1e-3, atol=5e-3) + stypes = ['row_sparse', 'default'] + for stype in stypes: + check_elemwise_add_training(stype) + if __name__ == '__main__': install.test_mkldnn_install() diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 0a9fcda56a64..6d6ba41ab7fb 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -407,7 +407,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p def test_quantized_fc(): def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): if is_test_for_native_cpu(): - hasMKL = False; + hasMKL = False for key in os.environ.keys(): if operator.eq(key, "BUILD_TAG"): if os.environ['BUILD_TAG'].find("MKL") != -1: @@ -617,12 +617,11 @@ def check_quantized_bn(data_shape, qdtype): # qdtype = uint8 if qdtype == 'uint8': data_low = 0.0 - data_high = 127.0 + data_high = 255.0 else: data_low = -127.0 data_high = 127.0 - # output type = int8 - quantized_range = 127.0 + # run fp32 bn data_sym = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') bn_fp32 = mx.sym.BatchNorm(data=data_sym, name='bn', use_global_stats=True, fix_gamma=False) @@ -653,12 +652,12 @@ def check_quantized_bn(data_shape, qdtype): calib_data = NDArrayIter(data=data, batch_size=data_shape[0]) calib_data = DummyIter(calib_data) - # quantize bn with quantized_type = int8: MKLDNN BN only support int8 output qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=bn_fp32, arg_params=arg_params, aux_params=bn_fp32_exe.aux_dict, ctx=mx.current_context(), - quantized_dtype='int8', + quantized_dtype=qdtype, + quantize_mode='full', calib_mode='naive', calib_data=calib_data, num_calib_examples=20) @@ -670,7 +669,7 @@ def check_quantized_bn(data_shape, qdtype): mod.forward(batch, is_train=False) output_int8_to_fp32 = mod.get_outputs()[0] - assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=4) + assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=8) for qdtype in ['int8', 'uint8']: check_quantized_bn((32, 512, 4, 4), qdtype) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 33f739bd10fc..376e177d0659 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -36,6 +36,15 @@ import os def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4): + if default_context().device_type == 'cpu': + # NOTE(zixuanweeei): Currently, we don't add `add` requests support on fused mkl-dnn rnn operator. + # We tracked this issue by https://github.com/apache/incubator-mxnet/issues/16578 + if isinstance(grad_req, dict) and 'add' in grad_req.values(): + print("Skip the test when requiring `add` operation against gradients on CPU context.") + return + if isinstance(grad_req, str) and grad_req == 'add': + print("Skip the test when requiring `add` operation against gradients on CPU context.") + return dshape = (N, T, I) data = mx.sym.Variable('data') @@ -86,7 +95,7 @@ def test_rnn_with_new_param(): for mode, ngates in zip(rnn_modes, ngates_): first_layer_size = (input_size * state_size + state_size * state_size + state_size * 2) * ngates rest_layer_size = (state_size * directions * state_size + state_size * state_size + state_size * 2) \ - * ngates * (num_layers - 1) + * ngates * (num_layers - 1) param_size = (first_layer_size + rest_layer_size) * directions sym = mx.sym.RNN(mode=mode, num_layers=num_layers, bidirectional=bidirectional, state_outputs=False, state_size=state_size, name='rnn') @@ -118,112 +127,133 @@ def test_rnn_with_new_param(): @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_lstm_sym(): - T, N, I, H = 5, 32, 800, 800 - fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='') - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.LSTMCell(H, prefix='l0_')) - stack.add(mx.rnn.LSTMCell(H, prefix='l1_')) - stack.add(mx.rnn.LSTMCell(H, prefix='l2_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.LSTMCell(H, prefix='l0_')) + stack.add(mx.rnn.LSTMCell(H, prefix='l1_')) + stack.add(mx.rnn.LSTMCell(H, prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_lstm_bidirectional(): - T, N, I, H = 5, 20, 800, 800 - fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm', - bidirectional=True, get_next_state=True, prefix='') - - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.LSTMCell(H, prefix='l0_'), - mx.rnn.LSTMCell(H, prefix='r0_'), - output_prefix='bi_lstm_0_')) - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.LSTMCell(H, prefix='l1_'), - mx.rnn.LSTMCell(H, prefix='r1_'), - output_prefix='bi_lstm_1_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') - check_rnn_consistency(fused, stack, T, N, I, H, {'data': 'add', 'parameters': 'null'}) + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.LSTMCell(H, prefix='l0_'), + mx.rnn.LSTMCell(H, prefix='r0_'), + output_prefix='bi_lstm_0_')) + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.LSTMCell(H, prefix='l1_'), + mx.rnn.LSTMCell(H, prefix='r1_'), + output_prefix='bi_lstm_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') + check_rnn_consistency(fused, stack, T, N, I, H, {'data': 'add', 'parameters': 'null'}) @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_gru_sym(): - T, N, I, H = 5, 32, 800, 800 - fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='') - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.GRUCell(H, prefix='l0_')) - stack.add(mx.rnn.GRUCell(H, prefix='l1_')) - stack.add(mx.rnn.GRUCell(H, prefix='l2_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.GRUCell(H, prefix='l0_')) + stack.add(mx.rnn.GRUCell(H, prefix='l1_')) + stack.add(mx.rnn.GRUCell(H, prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4) @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_gru_bidirectional(): - T, N, I, H = 5, 20, 800, 800 - - fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru', - bidirectional=True, get_next_state=True, prefix='') - - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.GRUCell(H, prefix='l0_'), - mx.rnn.GRUCell(H, prefix='r0_'), - output_prefix='bi_gru_0_')) - - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.GRUCell(H, prefix='l1_'), - mx.rnn.GRUCell(H, prefix='r1_'), - output_prefix='bi_gru_1_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.GRUCell(H, prefix='l0_'), + mx.rnn.GRUCell(H, prefix='r0_'), + output_prefix='bi_gru_0_')) + + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.GRUCell(H, prefix='l1_'), + mx.rnn.GRUCell(H, prefix='r1_'), + output_prefix='bi_gru_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4) + check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4) @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnntanh_sym(): - T, N, I, H = 5, 32, 800, 800 - - fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='') - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_')) - stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_')) - stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_')) + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_')) + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnntanh_bidirectional(): - T, N, I, H = 5, 20, 800, 800 - - fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh', - bidirectional=True, get_next_state=True, prefix='') - - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'), - mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'), - output_prefix='bi_rnntanh_0_')) - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'), - mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'), - output_prefix='bi_rnntanh_1_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'), + mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'), + output_prefix='bi_rnntanh_0_')) + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'), + mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'), + output_prefix='bi_rnntanh_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') @@ -243,24 +273,27 @@ def test_rnnrelu_sym(): @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnnrelu_bidirectional(): - T, N, I, H = 5, 20, 200, 200 - - fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu', - bidirectional=True, get_next_state=True, prefix='') - - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.RNNCell(H, activation='relu', prefix='l0_'), - mx.rnn.RNNCell(H, activation='relu', prefix='r0_'), - output_prefix='bi_rnnrelu_0_')) - stack.add(mx.rnn.BidirectionalCell( - mx.rnn.RNNCell(H, activation='relu', prefix='l1_'), - mx.rnn.RNNCell(H, activation='relu', prefix='r1_'), - output_prefix='bi_rnnrelu_1_')) - - check_rnn_consistency(fused, stack, T, N, I, H, 'write', rtol=1e-2, atol=1e-2) - check_rnn_consistency(fused, stack, T, N, I, H, 'add', rtol=1e-2, atol=1e-2) - check_rnn_consistency(fused, stack, T, N, I, H, 'null', rtol=1e-2, atol=1e-2) + Ts = [1, 5] + Ns = [1, 32] + Is = [32, 128, 512] + Hs = [32, 128, 512] + for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='relu', prefix='l0_'), + mx.rnn.RNNCell(H, activation='relu', prefix='r0_'), + output_prefix='bi_rnnrelu_0_')) + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='relu', prefix='l1_'), + mx.rnn.RNNCell(H, activation='relu', prefix='r1_'), + output_prefix='bi_rnnrelu_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write', rtol=1e-2, atol=1e-2) + check_rnn_consistency(fused, stack, T, N, I, H, 'add', rtol=1e-2, atol=1e-2) + check_rnn_consistency(fused, stack, T, N, I, H, 'null', rtol=1e-2, atol=1e-2) @with_seed() def test_lstm_dropout(): diff --git a/tools/pip/setup.py b/tools/pip/setup.py index 82aa632d28ad..dd430f5a6f87 100644 --- a/tools/pip/setup.py +++ b/tools/pip/setup.py @@ -147,19 +147,11 @@ def has_ext_modules(self): 'dmlc_tracker': []} if variant.endswith('MKL'): if platform.system() == 'Darwin': - shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libmklml.dylib'), os.path.join(CURRENT_DIR, 'mxnet')) - shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libiomp5.dylib'), os.path.join(CURRENT_DIR, 'mxnet')) - shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libmkldnn.0.dylib'), os.path.join(CURRENT_DIR, 'mxnet')) - package_data['mxnet'].append('mxnet/libmklml.dylib') - package_data['mxnet'].append('mxnet/libiomp5.dylib') - package_data['mxnet'].append('mxnet/libmkldnn.0.dylib') + shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libmkldnn.1.dylib'), os.path.join(CURRENT_DIR, 'mxnet')) + package_data['mxnet'].append('mxnet/libmkldnn.1.dylib') else: - shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libmklml_intel.so'), os.path.join(CURRENT_DIR, 'mxnet')) - shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libiomp5.so'), os.path.join(CURRENT_DIR, 'mxnet')) - shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libmkldnn.so.0'), os.path.join(CURRENT_DIR, 'mxnet')) - package_data['mxnet'].append('mxnet/libmklml_intel.so') - package_data['mxnet'].append('mxnet/libiomp5.so') - package_data['mxnet'].append('mxnet/libmkldnn.so.0') + shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libmkldnn.so.1'), os.path.join(CURRENT_DIR, 'mxnet')) + package_data['mxnet'].append('mxnet/libmkldnn.so.1') shutil.copytree(os.path.join(CURRENT_DIR, 'mxnet-build/3rdparty/mkldnn/build/install/include'), os.path.join(CURRENT_DIR, 'mxnet/include/mkldnn')) if platform.system() == 'Linux': diff --git a/tools/staticbuild/build_lib.sh b/tools/staticbuild/build_lib.sh index 927c15d1dabc..4a82b80d00ba 100755 --- a/tools/staticbuild/build_lib.sh +++ b/tools/staticbuild/build_lib.sh @@ -35,20 +35,14 @@ $MAKE DEPS_PATH=$DEPS_PATH PSLITE if [[ $VARIANT == *mkl ]]; then if [[ $PLATFORM == 'linux' ]]; then - IOMP_LIBFILE='libiomp5.so' - MKLML_LIBFILE='libmklml_intel.so' - MKLDNN_LIBFILE='libmkldnn.so.0' + MKLDNN_LIBFILE='libmkldnn.so.1' else - IOMP_LIBFILE='libiomp5.dylib' - MKLML_LIBFILE='libmklml.dylib' - MKLDNN_LIBFILE='libmkldnn.0.dylib' + MKLDNN_LIBFILE='libmkldnn.1.dylib' fi $MAKE DEPS_PATH=$DEPS_PATH mkldnn if [ ! -d lib ]; then mkdir lib fi - cp 3rdparty/mkldnn/build/install/lib/$IOMP_LIBFILE lib - cp 3rdparty/mkldnn/build/install/lib/$MKLML_LIBFILE lib cp 3rdparty/mkldnn/build/install/lib/$MKLDNN_LIBFILE lib fi