diff --git a/.asf.yaml b/.asf.yaml index 34e813f39639..7fd3f6930fb1 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -32,3 +32,17 @@ github: - vulkan - spirv - machine-learning + + # Triage perm for collaborators(test run) + # + # The perm is given based on needs and not based on + # evaluation of past contributions. The rationale + # is that people may need the permission to start + # contributing in this way. It serves to diversify + # the ways to contribute. + # + # There is a limited number of slots. To enable broad + # participation, permission is given on a three month + # cycle. PMC may review and recycle slots when necessary. + collaborators: + - denise-k diff --git a/CMakeLists.txt b/CMakeLists.txt index c40b0c878905..24f0653b3a78 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -693,7 +693,7 @@ if(USE_CCACHE) # True for AUTO, ON, /path/to/ccache message(STATUS "Found the path to ccache, enabling ccache") set(PATH_TO_CCACHE ccache) else() - message(FATAL_ERROR "Cannot find ccache. Set USE_CCACHE mode to AUTO or OFF to build without ccache. USE_CCACHE=" "${USE_CCACHE") + message(FATAL_ERROR "Cannot find ccache. Set USE_CCACHE mode to AUTO or OFF to build without ccache. USE_CCACHE=" "${USE_CCACHE}") endif(CCACHE_FOUND) else() # /path/to/ccache set(PATH_TO_CCACHE USE_CCACHE) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b9ef0479c72f..6c63793fa217 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -111,12 +111,14 @@ We do encourage everyone to work anything they are interested in. - [Andrew Z. Luo](https://github.com/AndrewZhaoLuo): @AndrewZhaoLuo - [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - [Masahiro Masuda](https://github.com/masahi): @masahi +- [Andrey Malyshev](https://github.com/elvin-n): @elvin-n - [Sergey Mironov](https://github.com/grwlf): @grwlf - [Thierry Moreau](https://github.com/tmoreau89): @tmoreau89 - [Kazutaka Morita](https://github.com/kazum): @kazum - [Trevor Morris](https://github.com/trevor-m): @trevor-m - [Tatsuya Nishiyama](https://github.com/nishi-t): @nishi-t - [Leandro Nunes](https://github.com/leandron): @leandron +- [Jiawei Liu](https://github.com/ganler): @ganler - [Lily Orth-Smith](https://github.com/electriclilies): @electriclilies - [Wei Pan](https://github.com/wpan11nv): @wpan11nv - [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic diff --git a/README.md b/README.md index 09ceb7ab1d07..d96038d17804 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ TVM is licensed under the [Apache-2.0](LICENSE) license. Getting Started --------------- Check out the [TVM Documentation](https://tvm.apache.org/docs/) site for installation instructions, tutorials, examples, and more. -The [Getting Started with TVM](https://tvm.apache.org/docs/tutorials/get_started/introduction.html) tutorial is a great +The [Getting Started with TVM](https://tvm.apache.org/docs/tutorial/introduction.html) tutorial is a great place to start. Contribute to TVM diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h index b89bedbc6d45..3adcb2dc8d42 100644 --- a/apps/bundle_deploy/crt_config/crt_config.h +++ b/apps/bundle_deploy/crt_config/crt_config.h @@ -37,7 +37,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/src/runtime/hexagon/launcher/README.md b/apps/hexagon_launcher/README.md similarity index 57% rename from src/runtime/hexagon/launcher/README.md rename to apps/hexagon_launcher/README.md index a8a570918514..b190dd81a7b2 100644 --- a/src/runtime/hexagon/launcher/README.md +++ b/apps/hexagon_launcher/README.md @@ -19,9 +19,7 @@ ## Compilation The launcher consists of two parts: part running on Hexagon, and part running -on Android. They need to be compiled separately. Since some source files are -shared between these two parts, make sure to delete all object files between -compilations. Compile the Hexagon code first. +on Android. Each component must be compiled separately. The supported Snapdragon architectures are 855, 865, and 888. @@ -33,45 +31,89 @@ The supported Snapdragon architectures are 855, 865, and 888. Android NDK can be downloaded from https://developer.android.com/ndk. Hexagon SDK is available at //developer.qualcomm.com/software/hexagon-dsp-sdk. -### Compilation of the Hexagon part +### Compilation with TVM -1. Build the static version of TVM runtime for Hexagon. Use Hexagon clang - from the Hexagon SDK. This step is the same as building the shared version, - except at the cmake step, add `-DBUILD_STATIC_RUNTIME=ON`. The compilation - step should create `libtvm_runtime.a`. +Building the Hexagon launcher application as a component of the main TVM build +used for Hexagon codegen can be achieved by setting `USE_HEXAGON_LAUNCHER=ON`. +This option will compile core tvm, the android launcher binary and its corresponding +tvm_runtime, as well as the Hexagon launcher shared library and its corresponding +tvm_runtime. As described in the [Manual compilation](#Manual compilation) section +each component requires Hexagon and android dependencies. When building the launcher +along with TVM these configurations must be providing when invoking cmake. A minimal +example invocation for compiling TVM along with the Hexagon launcher is included below: -2. Create a subdirectory for the build files, and run `cmake` with the - following variables set: - - `FASTRPC_LIBS=SKEL` - - `USE_HEXAGON_SDK` to the path to the Hexagon SDK - - `CMAKE_C_COMPILER=hexagon-clang` - - `CMAKE_CXX_COMPILER=hexagon-clang++` - - `USE_HEXAGON_ARCH` to one of v65, v66, v68 - - `TVM_RUNTIME_HEXAGON=/path/to/libtvm_runtime.a` _statically_ linked - TVM runtime +``` +cmake -DCMAKE_C_COMPILER=/path/to/clang \ + -DCMAKE_CXX_COMPILER=/path/to/clang++ \ + -DCMAKE_CXX_FLAGS='-stdlib=libc++' \ + -DCMAKE_CXX_STANDARD=14 \ + -DUSE_LLVM=/path/to/llvm/bin/llvm-config \ + -DUSE_HEXAGON_ARCH=v65|v66|v68 \ + -DUSE_HEXAGON_LAUNCHER=ON \ + -DUSE_HEXAGON_SDK=/path/to/hexagon/SDK \ + -DUSE_HEXAGON_TOOLCHAIN=/path/to/hexagon/toolchain/ .. + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DUSE_ANDROID_TOOLCHAIN=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ + .. +``` + +where `v65|v66|v68` means "one of" these architecture versions. +The Hexagon launcher application is an android binary and thus requires the use +of an android toolchain for compilation. Similarly, the Hexagon tvm runtime +requires the use of the Hexagon toolchain and depends on the Hexagon SDK. The +resulting hexagon launcher binaries can be found in the `apps_hexagon_launcher` +subdirectory of the cmake build directory. Please note that the above command +will not build support for Hexagon codegen in the TVM library, for that please +additionally define the `USE_HEXAGON_DEVICE` variable. Also, the LLVM used in +`USE_LLVM` should have Hexagon target built in. + +### Manual compilation - Make sure to provide the path to launcher's `CMakeLists.txt` directory - in `cmake` invocation. +Since some source files are shared between the Hexagon and android builds, +make sure to delete all object files between compilations. Compile the Hexagon +code first. -3. Run `make`. This will create `liblauncher_rpc_skel.so`. +#### Compilation of the Hexagon part + +Create a subdirectory for the build files, and run `cmake` with the +following variables set: + +``` +cmake -DCMAKE_C_COMPILER=/path/to/hexagon-clang \ + -DCMAKE_CXX_COMPILER=/path/to/hexagon-clang++ \ + -DUSE_HEXAGON_ARCH=v65|v66|v68 \ + -DUSE_HEXAGON_SDK=/path/to/hexagon/SDK \ + /path/to/apps/hexagon_launcher/cmake/hexagon +``` -### Compilation of the Android part +Run `make`. This will create `liblauncher_rpc_skel.so`. The static version of +the TVM runtime for Hexagon will be built as a part of the process. -1. Build TVM runtime for Android, using clang for AArch64 from the Android - NDK. Unlike in the Hexagon case, this should be the dynamic library (which - is the default), i.e. `libtvm_runtime.so`. +#### Compilation of the Android part 2. Create a subdirectory for the build files (different from the one used for Hexagon files), and run `cmake` with the following variables set: - - `FASTRPC_LIBS=STUB` - - `USE_HEXAGON_SDK` to the path to the Hexagon SDK - - `CMAKE_C_COMPILER=aarch64-linux-android28-clang` (or later) - - `CMAKE_CXX_COMPILER=aarch64-linux-android28-clang++` (or later) - - `USE_HEXAGON_ARCH` to one of v65, v66, v68 (same as for the Hexagon part) - - `TVM_RUNTIME_ANDROID=/path/to/libtvm_runtime.so` dynamically or - statically linked TVM runtime - -3. Run `make`. This will create `launcher_android`. + +``` +cmake -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DUSE_HEXAGON_SDK=/p/Hexagon_SDK/4.3.0.0 + -DUSE_HEXAGON_ARCH=v65|v66|v68 + /path/to/apps/hexagon_launcher/cmake/android +``` + +Run `make`. This will create `launcher_android`. The TVM runtime for Android will +be built as a part of the process. Depending on the version of cmake that you are +using, you may see the following warnings---they can be ignored. + +``` +An old version of CMake is being used that cannot automatically detect +compiler attributes. Compiler identification is being bypassed. Some +values may be wrong or missing. Update to CMake 3.19 or newer to use +CMake's built-in compiler identification. +``` ## Execution diff --git a/apps/hexagon_launcher/cmake/HexagonLauncher.cmake b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake new file mode 100644 index 000000000000..abf877cb67f1 --- /dev/null +++ b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(NOT DEFINED USE_HEXAGON_SDK) + message(SEND_ERROR "Please set USE_HEXAGON_SDK to the location of Hexagon SDK") +endif() +if (NOT DEFINED USE_HEXAGON_ARCH) + message(SEND_ERROR "Please set USE_HEXAGON_ARCH to the Hexagon architecture version") +endif() + +set(TVM_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../") + +include(ExternalProject) +include("${TVM_SOURCE_DIR}/cmake/modules/HexagonSDK.cmake") + +find_hexagon_sdk_root("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}") + +include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_REMOTE_ROOT}) + +set(QAIC_EXE "${HEXAGON_QAIC_EXE}") +foreach(INCDIR IN LISTS HEXAGON_SDK_INCLUDES HEXAGON_REMOTE_ROOT) + list(APPEND QAIC_FLAGS "-I${INCDIR}") +endforeach() + +set(LAUNCHER_SRC "${CMAKE_CURRENT_SOURCE_DIR}/../../") +set(CMAKE_SKIP_RPATH TRUE) + +# Qaic for the domain header. +# +# Don't add paths to these filenames, or otherwise cmake may spontaneously +# add -o option to the qaic invocation (with an undesirable path). +set(LAUNCHER_RPC_IDL "launcher_rpc.idl") +set(LAUNCHER_RPC_H "launcher_rpc.h") +set(LAUNCHER_RPC_SKEL_C "launcher_rpc_skel.c") +set(LAUNCHER_RPC_STUB_C "launcher_rpc_stub.c") + +include_directories( + "${LAUNCHER_SRC}" + "${TVM_SOURCE_DIR}/include" + "${TVM_SOURCE_DIR}/3rdparty/dlpack/include" + "${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include" +) diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt new file mode 100644 index 000000000000..7716cde99863 --- /dev/null +++ b/apps/hexagon_launcher/cmake/android/CMakeLists.txt @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.2) +project(HexagonAndroidLauncher C CXX) + +include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") + +add_custom_command( + OUTPUT ${LAUNCHER_RPC_STUB_C} ${LAUNCHER_RPC_H} + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" + MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" +) + +include_directories(SYSTEM + "${HEXAGON_SDK_INCLUDES}" + "${HEXAGON_RPCMEM_ROOT}/inc" + "${CMAKE_CURRENT_BINARY_DIR}" # Output of qaic will go here +) + +link_directories(${HEXAGON_REMOTE_ROOT}) + +add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) + +set(STUB_SRCS + "${LAUNCHER_SRC}/launcher_android.cc" + "${LAUNCHER_SRC}/launcher_core.cc" + "${LAUNCHER_SRC}/launcher_main.cc" + "${LAUNCHER_SRC}/launcher_util.cc" +) + +add_executable(launcher_android + "${LAUNCHER_RPC_H}" + "${LAUNCHER_RPC_STUB_C}" + "${STUB_SRCS}" +) + +ExternalProject_Add(android_tvm_runtime + SOURCE_DIR "${TVM_SOURCE_DIR}" + BUILD_COMMAND $(MAKE) runtime + CMAKE_ARGS + "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" + "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" + "-DANDROID_ABI=${ANDROID_ABI}" + "-DCMAKE_CXX_STANDARD=14" + "-DUSE_LIBBACKTRACE=OFF" + "-DUSE_LLVM=OFF" + "-DUSE_RPC=OFF" + INSTALL_COMMAND "" + BUILD_ALWAYS ON +) +ExternalProject_Get_Property(android_tvm_runtime BINARY_DIR) +ExternalProject_Add_Step(android_tvm_runtime copy_binaries + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/libtvm_runtime.so + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDEES install +) + +add_dependencies(launcher_android android_tvm_runtime) +add_library(a_tvm_runtime SHARED IMPORTED) +set_target_properties(a_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.so") + +target_link_libraries(launcher_android cdsprpc log a_tvm_runtime) diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt new file mode 100644 index 000000000000..3f99459f3a49 --- /dev/null +++ b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.2) +project(HexagonLauncherRPCSkel C CXX) + +include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") + +add_custom_command( + OUTPUT ${LAUNCHER_RPC_SKEL_C} ${LAUNCHER_RPC_H} + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" + MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" +) + +include_directories(SYSTEM + ${HEXAGON_QURT_INCLUDES} + ${CMAKE_CURRENT_BINARY_DIR} # Output of qaic will go here +) + +link_directories(${HEXAGON_QURT_LIBS}) + +add_definitions(-D_MACH_I32=int) +add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0) +add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) + +# Extra compile flags (both C and C++). +set(EXTRA_COMP_FLAGS + "-O3" + "-m${USE_HEXAGON_ARCH}" +) +string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") +set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") +set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") + +set(SKEL_SRCS + "${LAUNCHER_SRC}/launcher_core.cc" + "${LAUNCHER_SRC}/launcher_hexagon.cc" +) + +add_library(launcher_rpc_skel SHARED + "${LAUNCHER_RPC_H}" + "${LAUNCHER_RPC_SKEL_C}" + "${SKEL_SRCS}" +) + +ExternalProject_Add(static_hexagon_tvm_runtime + SOURCE_DIR "${TVM_SOURCE_DIR}" + BUILD_COMMAND $(MAKE) runtime + CMAKE_ARGS + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" + "-DCMAKE_CXX_STANDARD=14" + "-DUSE_LIBBACKTRACE=OFF" + "-DUSE_LLVM=OFF" + "-DUSE_RPC=OFF" + "-DBUILD_STATIC_RUNTIME=ON" + "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" + INSTALL_COMMAND "" + BUILD_ALWAYS ON +) +ExternalProject_Get_Property(static_hexagon_tvm_runtime BINARY_DIR) + +add_dependencies(launcher_rpc_skel static_hexagon_tvm_runtime) +add_library(h_tvm_runtime STATIC IMPORTED) +set_target_properties(h_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.a") + +target_link_libraries(launcher_rpc_skel -Wl,--whole-archive h_tvm_runtime -Wl,--no-whole-archive) + diff --git a/src/runtime/hexagon/launcher/launcher_android.cc b/apps/hexagon_launcher/launcher_android.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_android.cc rename to apps/hexagon_launcher/launcher_android.cc diff --git a/src/runtime/hexagon/launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_core.cc rename to apps/hexagon_launcher/launcher_core.cc diff --git a/src/runtime/hexagon/launcher/launcher_core.h b/apps/hexagon_launcher/launcher_core.h similarity index 100% rename from src/runtime/hexagon/launcher/launcher_core.h rename to apps/hexagon_launcher/launcher_core.h diff --git a/src/runtime/hexagon/launcher/launcher_hexagon.cc b/apps/hexagon_launcher/launcher_hexagon.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_hexagon.cc rename to apps/hexagon_launcher/launcher_hexagon.cc diff --git a/src/runtime/hexagon/launcher/launcher_main.cc b/apps/hexagon_launcher/launcher_main.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_main.cc rename to apps/hexagon_launcher/launcher_main.cc diff --git a/src/runtime/hexagon/launcher/launcher_rpc.idl b/apps/hexagon_launcher/launcher_rpc.idl similarity index 100% rename from src/runtime/hexagon/launcher/launcher_rpc.idl rename to apps/hexagon_launcher/launcher_rpc.idl diff --git a/src/runtime/hexagon/launcher/launcher_util.cc b/apps/hexagon_launcher/launcher_util.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_util.cc rename to apps/hexagon_launcher/launcher_util.cc diff --git a/src/runtime/hexagon/launcher/launcher_util.h b/apps/hexagon_launcher/launcher_util.h similarity index 100% rename from src/runtime/hexagon/launcher/launcher_util.h rename to apps/hexagon_launcher/launcher_util.h diff --git a/apps/microtvm/arduino/README.md b/apps/microtvm/arduino/README.md new file mode 100644 index 000000000000..b33557b53239 --- /dev/null +++ b/apps/microtvm/arduino/README.md @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + +This directory contains code to interface microTVM with [Arduino](https://www.arduino.cc/). diff --git a/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h b/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h deleted file mode 100644 index cf73103aff8b..000000000000 --- a/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief CRT configuration for the host-linked CRT. - */ -#ifndef TVM_RUNTIME_MICRO_CRT_CONFIG_H_ -#define TVM_RUNTIME_MICRO_CRT_CONFIG_H_ - -/*! Log level of the CRT runtime */ -#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG - -/*! Support low-level debugging in MISRA-C runtime */ -#define TVM_CRT_DEBUG 0 - -/*! Maximum supported dimension in NDArray */ -#define TVM_CRT_MAX_NDIM 6 -/*! Maximum supported arguments in generated functions */ -#define TVM_CRT_MAX_ARGS 10 -/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ -#define TVM_CRT_MAX_STRLEN_DLTYPE 10 -/*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 - -/*! Maximum number of registered modules. */ -#define TVM_CRT_MAX_REGISTERED_MODULES 2 - -/*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512 - -/*! Maximum packet size, in bytes, including the length header. */ -#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8 * 1024 - -/*! \brief Maximum length of a PackedFunc function name. */ -#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 - -// #define TVM_CRT_FRAMER_ENABLE_LOGS - -#endif // TVM_RUNTIME_MICRO_CRT_CONFIG_H_ diff --git a/apps/microtvm/arduino/template_project/boards.json b/apps/microtvm/arduino/template_project/boards.json new file mode 100644 index 000000000000..595d56b5f615 --- /dev/null +++ b/apps/microtvm/arduino/template_project/boards.json @@ -0,0 +1,59 @@ +{ + "due": { + "package": "arduino", + "architecture": "sam", + "board": "arduino_due_x_dbg", + "model": "sam3x8e" + }, + "feathers2": { + "package": "esp32", + "architecture": "esp32", + "board": "feathers2", + "model": "esp32", + "note": "Due to the way the Feather S2 bootloader works, compilation behaves fine but uploads cannot be done automatically." + }, + "metrom4": { + "package": "adafruit", + "architecture": "samd", + "board": "adafruit_metro_m4", + "model": "atsamd51" + }, + "spresense": { + "package": "SPRESENSE", + "architecture": "spresense", + "board": "spresense", + "model": "cxd5602gg", + "note": "Spresense only works as of its v2.3.0 sdk." + }, + "nano33ble": { + "package": "arduino", + "architecture": "mbed_nano", + "board": "nano33ble", + "model": "nrf52840" + }, + "pybadge": { + "package": "adafruit", + "architecture": "samd", + "board": "adafruit_pybadge_m4", + "model": "atsamd51" + }, + "teensy40": { + "package": "teensy", + "architecture": "avr", + "board": "teensy40", + "model": "imxrt1060", + "note": "The Teensy boards are listed here for completeness, but they won't work until https://github.com/arduino/arduino-cli/issues/700 is finished." + }, + "teensy41": { + "package": "teensy", + "architecture": "avr", + "board": "teensy41", + "model": "imxrt1060" + }, + "wioterminal": { + "package": "Seeeduino", + "architecture": "samd", + "board": "seeed_wio_terminal", + "model": "atsamd51" + } +} diff --git a/apps/microtvm/arduino/example_project/src/standalone_crt/crt_config/crt_config.h b/apps/microtvm/arduino/template_project/crt_config/crt_config.h similarity index 93% rename from apps/microtvm/arduino/example_project/src/standalone_crt/crt_config/crt_config.h rename to apps/microtvm/arduino/template_project/crt_config/crt_config.h index cf73103aff8b..b3126cfac920 100644 --- a/apps/microtvm/arduino/example_project/src/standalone_crt/crt_config/crt_config.h +++ b/apps/microtvm/arduino/template_project/crt_config/crt_config.h @@ -36,7 +36,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 3d25d0bcad8f..e285ecc6e3b0 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -44,77 +44,21 @@ IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() +BOARDS = API_SERVER_DIR / "boards.json" + +# Data structure to hold the information microtvm_api_server.py needs +# to communicate with each of these boards. +try: + with open(BOARDS) as boards: + BOARD_PROPERTIES = json.load(boards) +except FileNotFoundError: + raise FileNotFoundError(f"Board file {{{BOARDS}}} does not exist.") + + class BoardAutodetectFailed(Exception): """Raised when no attached hardware is found matching the requested board""" -# Data structure to hold the information microtvm_api_server.py needs -# to communicate with each of these boards. Currently just holds the -# components of each board's FQBN, but might be extended in the future -# to include the SRAM, PSRAM, flash, etc. on each board. -BOARD_PROPERTIES = { - "due": { - "package": "arduino", - "architecture": "sam", - "board": "arduino_due_x_dbg", - "model": "sam3x8e", - }, - # Due to the way the Feather S2 bootloader works, compilation - # behaves fine but uploads cannot be done automatically - "feathers2": { - "package": "esp32", - "architecture": "esp32", - "board": "feathers2", - "model": "esp32", - }, - "metrom4": { - "package": "adafruit", - "architecture": "samd", - "board": "adafruit_metro_m4", - "model": "atsamd51", - }, - # Spresense only works as of its v2.3.0 sdk - "spresense": { - "package": "SPRESENSE", - "architecture": "spresense", - "board": "spresense", - "model": "cxd5602gg", - }, - "nano33ble": { - "package": "arduino", - "architecture": "mbed_nano", - "board": "nano33ble", - "model": "nrf52840", - }, - "pybadge": { - "package": "adafruit", - "architecture": "samd", - "board": "adafruit_pybadge_m4", - "model": "atsamd51", - }, - # The Teensy boards are listed here for completeness, but they - # won't work until https://github.com/arduino/arduino-cli/issues/700 - # is finished - "teensy40": { - "package": "teensy", - "architecture": "avr", - "board": "teensy40", - "model": "imxrt1060", - }, - "teensy41": { - "package": "teensy", - "architecture": "avr", - "board": "teensy41", - "model": "imxrt1060", - }, - "wioterminal": { - "package": "Seeeduino", - "architecture": "samd", - "board": "seeed_wio_terminal", - "model": "atsamd51", - }, -} - PROJECT_TYPES = ["example_project", "host_driven"] PROJECT_OPTIONS = [ @@ -123,11 +67,6 @@ class BoardAutodetectFailed(Exception): choices=list(BOARD_PROPERTIES), help="Name of the Arduino board to build for", ), - server.ProjectOption( - "arduino_model", - choices=[board["model"] for _, board in BOARD_PROPERTIES.items()], - help="Name of the model for each Arduino board.", - ), server.ProjectOption("arduino_cli_cmd", help="Path to the arduino-cli tool."), server.ProjectOption("port", help="Port to use for connecting to hardware"), server.ProjectOption( @@ -166,8 +105,9 @@ def _copy_project_files(self, api_server_dir, project_dir, project_type): so this file is copied separately in generate_project. """ - project_types_folder = api_server_dir.parents[0] - for item in (project_types_folder / project_type / "src").iterdir(): + for item in (API_SERVER_DIR / "src" / project_type).iterdir(): + if item.name == "project.ino": + continue dest = project_dir / "src" / item.name if item.is_dir(): shutil.copytree(item, dest) @@ -176,7 +116,7 @@ def _copy_project_files(self, api_server_dir, project_dir, project_type): # Arduino requires the .ino file have the same filename as its containing folder shutil.copy2( - project_types_folder / project_type / "project.ino", + API_SERVER_DIR / "src" / project_type / "project.ino", project_dir / f"{project_dir.stem}.ino", ) @@ -344,12 +284,20 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec # Copies files from the template folder to project_dir shutil.copy2(API_SERVER_DIR / "microtvm_api_server.py", project_dir) + shutil.copy2(BOARDS, project_dir / BOARDS.name) self._copy_project_files(API_SERVER_DIR, project_dir, options["project_type"]) # Copy standalone_crt into src folder self._copy_standalone_crt(source_dir, standalone_crt_dir) self._remove_unused_components(source_dir, options["project_type"]) + # Populate crt-config.h + crt_config_dir = project_dir / "src" / "standalone_crt" / "crt_config" + crt_config_dir.mkdir() + shutil.copy2( + API_SERVER_DIR / "crt_config" / "crt_config.h", crt_config_dir / "crt_config.h" + ) + # Unpack the MLF and copy the relevant files metadata = self._disassemble_mlf(model_library_format_path, source_dir) shutil.copy2(model_library_format_path, source_dir / "model") diff --git a/apps/microtvm/arduino/example_project/src/model.c b/apps/microtvm/arduino/template_project/src/example_project/model.c similarity index 100% rename from apps/microtvm/arduino/example_project/src/model.c rename to apps/microtvm/arduino/template_project/src/example_project/model.c diff --git a/apps/microtvm/arduino/example_project/src/model.h b/apps/microtvm/arduino/template_project/src/example_project/model.h similarity index 100% rename from apps/microtvm/arduino/example_project/src/model.h rename to apps/microtvm/arduino/template_project/src/example_project/model.h diff --git a/apps/microtvm/arduino/example_project/project.ino b/apps/microtvm/arduino/template_project/src/example_project/project.ino similarity index 100% rename from apps/microtvm/arduino/example_project/project.ino rename to apps/microtvm/arduino/template_project/src/example_project/project.ino diff --git a/apps/microtvm/arduino/host_driven/src/model_support.c b/apps/microtvm/arduino/template_project/src/host_driven/model_support.c similarity index 100% rename from apps/microtvm/arduino/host_driven/src/model_support.c rename to apps/microtvm/arduino/template_project/src/host_driven/model_support.c diff --git a/apps/microtvm/arduino/host_driven/project.ino b/apps/microtvm/arduino/template_project/src/host_driven/project.ino similarity index 100% rename from apps/microtvm/arduino/host_driven/project.ino rename to apps/microtvm/arduino/template_project/src/host_driven/project.ino diff --git a/apps/microtvm/reference-vm/base-box-tool.py b/apps/microtvm/reference-vm/base-box-tool.py index 3a5fd18cede7..42b90c661704 100755 --- a/apps/microtvm/reference-vm/base-box-tool.py +++ b/apps/microtvm/reference-vm/base-box-tool.py @@ -388,13 +388,15 @@ def test_command(args): microtvm_test_config["microtvm_board"] = args.microtvm_board providers = args.provider - provider_passed = {p: False for p in providers} release_test_dir = os.path.join(THIS_DIR, f"release-test-{args.platform}") - if args.skip_build: - assert len(providers) == 1, "--skip-build was given, but >1 provider specified" + if args.skip_build or args.skip_destroy: + assert ( + len(providers) == 1 + ), "--skip-build and/or --skip-destroy was given, but >1 provider specified" + test_failed = False for provider_name in providers: try: if not args.skip_build: @@ -408,18 +410,27 @@ def test_command(args): microtvm_test_config, args.test_device_serial, ) - provider_passed[provider_name] = True + + except subprocess.CalledProcessError: + test_failed = True + sys.exit( + f"\n\nERROR: Provider '{provider_name}' failed the release test. " + "You can re-run it to reproduce the issue without building everything " + "again by passing the --skip-build and specifying only the provider that failed. " + "The VM is still running in case you want to connect it via SSH to " + "investigate further the issue, thus it's necessary to destroy it manually " + "to release the resources back to the host, like a USB device attached to the VM." + ) finally: - if not args.skip_build and len(providers) > 1: + # if we reached out here do_run_release_test() succeeded, hence we can + # destroy the VM and release the resources back to the host if user haven't + # requested to not destroy it. + if not (args.skip_destroy or test_failed): subprocess.check_call(["vagrant", "destroy", "-f"], cwd=release_test_dir) shutil.rmtree(release_test_dir) - if not all(provider_passed[p] for p in provider_passed.keys()): - sys.exit( - "some providers failed release test: " - + ",".join(name for name, passed in provider_passed if not passed) - ) + print(f'\n\nThe release tests passed on all specified providers: {", ".join(providers)}.') def release_command(args): @@ -493,9 +504,20 @@ def parse_args(): "--skip-build", action="store_true", help=( - "If given, assume a box has already been built in " - "the release-test subdirectory. Attach a USB device to this box and execute the " - "release test script--do not delete it." + "If given, assume a box has already been built in the release-test subdirectory, " + "so use that box to execute the release test script. If the tests fail the VM used " + "for testing will be left running for further investigation and will need to be " + "destroyed manually. If all tests pass on all specified providers no VM is left running, " + "unless --skip-destroy is given too." + ), + ) + parser_test.add_argument( + "--skip-destroy", + action="store_true", + help=( + "Skip destroying the test VM even if all tests pass. Can only be used if a single " + "provider is specified. Default is to destroy the VM if all tests pass (and always " + "skip destroying it if a test fails)." ), ) parser_test.add_argument( diff --git a/apps/microtvm/zephyr/README.md b/apps/microtvm/zephyr/README.md index ad00393c0805..68e9975d4b1c 100644 --- a/apps/microtvm/zephyr/README.md +++ b/apps/microtvm/zephyr/README.md @@ -15,5 +15,5 @@ -This directory code to interface microTVM with the [Zephyr RTOS](https://zephyrproject.org/). +This directory contains code to interface microTVM with the [Zephyr RTOS](https://zephyrproject.org/). diff --git a/apps/microtvm/zephyr/template_project/boards.json b/apps/microtvm/zephyr/template_project/boards.json index aabed3322150..18e393897f04 100644 --- a/apps/microtvm/zephyr/template_project/boards.json +++ b/apps/microtvm/zephyr/template_project/boards.json @@ -39,7 +39,7 @@ "board": "qemu_riscv32", "model": "host", "is_qemu": true, - "fpu": true + "fpu": false }, "qemu_riscv64": { "board": "qemu_riscv64", diff --git a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h index f8fc7514a28d..c3beaed522f2 100644 --- a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h +++ b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h @@ -36,7 +36,7 @@ #define TVM_CRT_MAX_ARGS 10 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 @@ -48,7 +48,10 @@ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 + +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 diff --git a/apps/microtvm/zephyr/template_project/src/host_driven/main.c b/apps/microtvm/zephyr/template_project/src/host_driven/main.c index 43064e804193..44d656028cbc 100644 --- a/apps/microtvm/zephyr/template_project/src/host_driven/main.c +++ b/apps/microtvm/zephyr/template_project/src/host_driven/main.c @@ -260,11 +260,6 @@ void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { // The main function of this application. extern void __stdout_hook_install(int (*hook)(int)); void main(void) { - // TODO (mehrdadh): Update this when zephyr version has updated to 2.6. - // Update zephyr to latest version to use with qemu_riscv32. -#ifdef CONFIG_BOARD_QEMU_RISCV32 - k_float_enable(_current, 0); -#endif #ifdef CONFIG_LED int ret; diff --git a/cmake/config.cmake b/cmake/config.cmake index ade9d5c815c1..1fce11f90aed 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -279,6 +279,9 @@ set(USE_FALLBACK_STL_MAP OFF) set(USE_HEXAGON_DEVICE OFF) set(USE_HEXAGON_SDK /path/to/sdk) +# Whether to build the hexagon launcher +set(USE_HEXAGON_LAUNCHER OFF) + # Hexagon architecture to target when compiling TVM itself (not the target for # compiling _by_ TVM). This applies to components like the TVM runtime, but is # also used to select correct include/library paths from the Hexagon SDK when diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index eb3ad1f5ae4a..88623ab045fd 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -53,23 +53,87 @@ if(BUILD_FOR_HEXAGON) include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_QURT_INCLUDES}) endif() -if(USE_HEXAGON_DEVICE STREQUAL "OFF") - list(APPEND COMPILER_SRCS src/target/opt/build_hexagon_off.cc) - return() -elseif(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}" AND - NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") - set(ERROR_MSG +if(USE_HEXAGON_LAUNCHER STREQUAL "ON") + set(USE_HEXAGON_DEVICE "${PICK_SIM}") +else() + if(USE_HEXAGON_DEVICE STREQUAL "OFF") + list(APPEND COMPILER_SRCS src/target/opt/build_hexagon_off.cc) + return() + elseif(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}" AND + NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") + set(ERROR_MSG "USE_HEXAGON_DEVICE must be one of [${PICK_NONE}|${PICK_SIM}|${PICK_HW}]") - message(SEND_ERROR "${ERROR_MSG}") - return() + message(SEND_ERROR "${ERROR_MSG}") + return() + endif() endif() -# If USE_HEXAGON_DEVICE is set to a valid value, make sure that USE_HEXAGON_SDK + +# If USE_HEXAGON_DEVICE/LAUNCHER is set to a valid value, make sure that USE_HEXAGON_SDK # is defined. if(NOT USE_HEXAGON_SDK) message(SEND_ERROR "Please set USE_HEXAGON_SDK to the Hexagon SDK root") return() endif() +if(USE_HEXAGON_LAUNCHER STREQUAL "ON") + if(DEFINED USE_ANDROID_TOOLCHAIN) + if(NOT DEFINED ANDROID_PLATFORM) + message(SEND_ERROR "Please set ANDROID_PLATFORM " + "when providing an Android cmake toolchain.") + endif() + if(NOT DEFINED ANDROID_ABI) + message(SEND_ERROR "Please set ANDROID_ABI " + "when providing an Android cmake toolchain.") + endif() + else() + message(SEND_ERROR "Please set USE_ANDROID_TOOLCHAIN to build the android " + " launcher for hexagon.") + endif() + + set(LAUNCHER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/apps_hexagon_launcher") + ExternalProject_Add(launcher_android + SOURCE_DIR "${CMAKE_SOURCE_DIR}/apps/hexagon_launcher/cmake/android" + INSTALL_DIR "${LAUNCHER_BINARY_DIR}" + BUILD_ALWAYS ON + CMAKE_ARGS + "-DCMAKE_TOOLCHAIN_FILE=${USE_ANDROID_TOOLCHAIN}" + "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" + "-DANDROID_ABI=${ANDROID_ABI}" + "-DFASTRPC_LIBS=STUB" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" + "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" + INSTALL_COMMAND "" + ) + ExternalProject_Get_Property(launcher_android BINARY_DIR) + ExternalProject_Add_Step(launcher_android copy_binaries + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/launcher_android ${BINARY_DIR}/libtvm_runtime.so + ${LAUNCHER_BINARY_DIR} + DEPENDEES install + ) + ExternalProject_Add(launcher_hexagon + SOURCE_DIR "${CMAKE_SOURCE_DIR}/apps/hexagon_launcher/cmake/hexagon" + INSTALL_DIR "${LAUNCHER_BINARY_DIR}" + BUILD_ALWAYS ON + CMAKE_ARGS + "-DCMAKE_C_COMPILER=${USE_HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang" + "-DCMAKE_CXX_COMPILER=${USE_HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++" + "-DFASTRPC_LIBS=SKEL" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" + "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" + INSTALL_COMMAND "" + ) + ExternalProject_Get_Property(launcher_hexagon BINARY_DIR) + ExternalProject_Add_Step(launcher_hexagon copy_binaries + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/liblauncher_rpc_skel.so + ${LAUNCHER_BINARY_DIR} + DEPENDEES install + ) + + set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES "${LAUNCHER_BINARY_DIR}") +endif() + if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") find_hexagon_toolchain() message(STATUS "Hexagon toolchain: ${HEXAGON_TOOLCHAIN}") diff --git a/cmake/modules/contrib/ExampleTargetHooks.cmake b/cmake/modules/contrib/ExampleTargetHooks.cmake index eb53dda133d2..e9003b02103e 100644 --- a/cmake/modules/contrib/ExampleTargetHooks.cmake +++ b/cmake/modules/contrib/ExampleTargetHooks.cmake @@ -15,5 +15,5 @@ # specific language governing permissions and limitations # under the License. -file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc) +file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/*.cc) list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC}) diff --git a/docker/install/ubuntu_install_ethosu_driver_stack.sh b/docker/install/ubuntu_install_ethosu_driver_stack.sh index 35b2b4c74b7b..db8b47399390 100755 --- a/docker/install/ubuntu_install_ethosu_driver_stack.sh +++ b/docker/install/ubuntu_install_ethosu_driver_stack.sh @@ -24,7 +24,7 @@ fvp_dir="/opt/arm/FVP_Corstone_SSE-300_Ethos-U55" cmake_dir="/opt/arm/cmake" ethosu_dir="/opt/arm/ethosu" ethosu_driver_ver="21.05" -cmsis_ver="5.7.0" +cmsis_ver="5.8.0" mkdir -p /opt/arm @@ -92,3 +92,13 @@ cd "${ethosu_dir}" git clone "https://github.com/ARM-software/CMSIS_5.git" cmsis cd cmsis git checkout -f tags/${cmsis_ver} + +# Build Driver +mkdir ${ethosu_dir}/core_driver/build && cd ${ethosu_dir}/core_driver/build +cmake -DCMAKE_TOOLCHAIN_FILE=${ethosu_dir}/core_platform/cmake/toolchain/arm-none-eabi-gcc.cmake -DETHOSU_LOG_SEVERITY=debug -DTARGET_CPU=cortex-m55 .. +make + +# Build NN Library +mkdir ${ethosu_dir}/cmsis/CMSIS/NN/build/ && cd ${ethosu_dir}/cmsis/CMSIS/NN/build/ +cmake .. -DCMAKE_TOOLCHAIN_FILE=${ethosu_dir}/core_platform/cmake/toolchain/arm-none-eabi-gcc.cmake -DTARGET_CPU=cortex-m55 -DBUILD_CMSIS_NN_FUNCTIONS=YES +make diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index d1fa340ac37d..fb0f596d6552 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -36,6 +36,6 @@ pip3 install \ pytest-xdist \ requests \ scipy \ - synr==0.4.1 \ + synr==0.5.0 \ six \ tornado diff --git a/docker/install/ubuntu_install_qemu.sh b/docker/install/ubuntu_install_qemu.sh index 1189f2bb8dd4..6682795b0fd8 100755 --- a/docker/install/ubuntu_install_qemu.sh +++ b/docker/install/ubuntu_install_qemu.sh @@ -54,7 +54,7 @@ apt update apt-get -y build-dep qemu gpg --keyserver keyserver.ubuntu.com --recv-keys 0x3353C9CEF108B584 -cat <qemu-5.1.0.tar.xz.sig +cat <${QEMU_SIG_FILE} -----BEGIN PGP ARMORED FILE----- Comment: Use "gpg --dearmor" for unpacking @@ -68,7 +68,7 @@ p5ez/+2k4VAIwIQoP5DoO06waLBffvLIAdPPKYsx71K67OoGG2svc7duC/+5qf1x =hCS7 -----END PGP ARMORED FILE----- EOF -curl -OLs https://download.qemu.org/qemu-5.1.0.tar.xz +curl -OLs https://download.qemu.org/${QEMU_TAR_FILE} gpg --verify ${QEMU_SIG_FILE} tar -xf ${QEMU_TAR_FILE} diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index 66720d411832..12208bbe6643 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -29,5 +29,5 @@ pip3 install \ matplotlib \ sphinx \ sphinx_autodoc_annotation \ - sphinx-gallery==0.4.1 \ + sphinx-gallery==0.4.0 \ sphinx_rtd_theme diff --git a/docs/arch/relay_op_strategy.rst b/docs/arch/relay_op_strategy.rst index c40251d22433..dbac7c821827 100644 --- a/docs/arch/relay_op_strategy.rst +++ b/docs/arch/relay_op_strategy.rst @@ -269,14 +269,14 @@ will then be chosen. Implementations with same priority level in this case leads to an undefined behavior, and any of them might be selected. The selection policy for ops with symbolic input shapes is still work in -progess. Currently, if any input tensor has a symbolic shape, only the +progress. Currently, if any input tensor has a symbolic shape, only the implementation with highest priority level will be used for this operator. This -will be updated after the implemention finishes. +will be updated after the implementation finishes. For debug purpose, you can add the following lines before you compile the Relay model to learn which implementation is used for each operator. .. code:: python - logging.getLogger("compile_engine").setLevel(logging.INFO) - logging.getLogger("compile_engine").addHandler(logging.StreamHandler(sys.stdout)) + logging.getLogger("te_compiler").setLevel(logging.INFO) + logging.getLogger("te_compiler").addHandler(logging.StreamHandler(sys.stdout)) diff --git a/docs/dev/how_to/relay_add_op.rst b/docs/dev/how_to/relay_add_op.rst index f9ade45f0800..2a8c771dc63d 100644 --- a/docs/dev/how_to/relay_add_op.rst +++ b/docs/dev/how_to/relay_add_op.rst @@ -190,18 +190,16 @@ useful for fusing operators. ``kOpaque`` tells TVM to not bother trying to fuse While we've now defined the interface for our operations we still need to define how to perform the actual calculations for cumulative sum and product. -Writing this code is outside the scope of the tutorial. For now, we assume -we have a well tested implementation for the operation's compute. For -more details on how to do this, we recommend looking up the tutorials -on `tensor expressions`_, `TVM's operator inventory (topi)`_ and looking at the -example cumulative sum and product implementations found in `python/tvm/topi/scan.py`_ -and the gpu versions in `python/tvm/topi/cuda/scan.py`_. In the case of our cumulative -sum and product operations we write things directly in `TIR`_ which is the +Writing this code is outside the scope of the tutorial. For now, we assume we +have a well tested implementation for the operation's compute. For more details +on how to do this, we recommend looking up the tutorials on :ref:`tensor +expressions `, :ref:`TVM's operator inventory +(topi) ` and looking at the example cumulative sum and product +implementations found in `python/tvm/topi/scan.py`_ and the gpu versions in +`python/tvm/topi/cuda/scan.py`_. In the case of our cumulative sum and product +operations we write things directly in :ref:`TIR ` which is the representation where tensor expressions and topi will lower into. -.. _tensor expressions: https://tvm.apache.org/docs/tutorials/get_started/tensor_expr_get_started.html -.. _TVM's operator inventory (topi): https://tvm.apache.org/docs/tutorials/topi/intro_topi.html -.. _TIR: https://tvm.apache.org/docs/dev/index.html?highlight=tir#tvm-tir .. _python/tvm/topi/scan.py: https://github.com/apache/tvm/blob/main/python/tvm/topi/scan.py .. _python/tvm/topi/cuda/scan.py: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/scan.py diff --git a/docs/how_to/deploy/arm_compute_lib.rst b/docs/how_to/deploy/arm_compute_lib.rst index 6fb531a0a8f6..831438273cca 100644 --- a/docs/how_to/deploy/arm_compute_lib.rst +++ b/docs/how_to/deploy/arm_compute_lib.rst @@ -142,9 +142,9 @@ Export the module. lib.export_library(lib_path, cc=cross_compile) -Run Inference. This must be on an Arm device. If compiling on x86 device and running on AArch64, -consider using the RPC mechanism. Tutorials for using the RPC mechanism: -https://tvm.apache.org/docs/tutorials/get_started/cross_compilation_and_rpc.html +Run Inference. This must be on an Arm device. If compiling on x86 device and +running on AArch64, consider using the RPC mechanism. :ref:`Tutorials for using +the RPC mechanism ` .. code:: python diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index b28c18162437..23be3198bf7c 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -107,7 +107,7 @@ The configuration of TVM can be modified by editing `config.cmake` and/or by pas .. code:: bash - export TVM_LOG_DEBUG=1 + export TVM_LOG_DEBUG="ir/transform.cc=1;relay/ir/transform.cc=1" - TVM requires LLVM for for CPU codegen. We highly recommend you to build with the LLVM support on. diff --git a/docs/reference/api/python/relay/backend.rst b/docs/reference/api/python/relay/backend.rst index ffe8a9a8ce79..e717ee10ffab 100644 --- a/docs/reference/api/python/relay/backend.rst +++ b/docs/reference/api/python/relay/backend.rst @@ -23,7 +23,7 @@ tvm.relay.backend .. automodule:: tvm.relay.backend.interpreter :members: -.. automodule:: tvm.relay.backend.compile_engine +.. automodule:: tvm.relay.backend.te_compiler :members: .. automodule:: tvm.relay.backend.graph_executor_codegen diff --git a/docs/reference/api/python/tir.rst b/docs/reference/api/python/tir.rst index b0b8f1cff5fb..2152be69ea6f 100644 --- a/docs/reference/api/python/tir.rst +++ b/docs/reference/api/python/tir.rst @@ -15,6 +15,8 @@ specific language governing permissions and limitations under the License. +.. _api-python-tir: + tvm.tir ------- .. automodule:: tvm.tir diff --git a/docs/topic/vta/install.rst b/docs/topic/vta/install.rst index 2248975b61b1..e4b309ea9b61 100644 --- a/docs/topic/vta/install.rst +++ b/docs/topic/vta/install.rst @@ -30,8 +30,8 @@ We present three installation guides, each extending on the previous one: VTA Simulator Installation -------------------------- -You need `TVM installed `_ on your machine. -For a quick and easy start, checkout the `Docker Guide `_. +You need :ref:`TVM installed ` on your machine. For a quick and +easy start, checkout the :ref:`Docker Guide `. You'll need to set the following paths to use VTA: @@ -65,7 +65,7 @@ To ensure that you've properly installed the VTA python package, run the followi python /vta/tests/python/integration/test_benchmark_topi_conv2d.py -You are invited to try out our `VTA programming tutorials `_. +You are invited to try out our :ref:`VTA programming tutorials `. **Note**: You'll notice that for every convolution layer, the throughput gets reported in GOPS. These numbers are actually the computational throughput that the simulator achieves, by evaluating the convolutions in software. @@ -222,9 +222,7 @@ The performance metrics measured on the Pynq board will be reported for each con **Tip**: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq ``ssh`` session. -You can also try out our `VTA programming tutorials `_. - - +You can also try out our :ref:`VTA programming tutorials `. Intel DE10 FPGA Setup --------------------- diff --git a/gallery/how_to/deploy_models/deploy_prequantized_tflite.py b/gallery/how_to/deploy_models/deploy_prequantized_tflite.py index 7bbb06bdf801..830e2ab07466 100644 --- a/gallery/how_to/deploy_models/deploy_prequantized_tflite.py +++ b/gallery/how_to/deploy_models/deploy_prequantized_tflite.py @@ -255,8 +255,8 @@ def run_tvm(lib): # * Set the environment variable TVM_NUM_THREADS to the number of physical cores # * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or # "llvm -mcpu=cascadelake" (more CPUs with AVX512 would come in the future) -# * Perform autotuning - `Auto-tuning a convolution network for x86 CPU -# `_. -# * To get best inference performance on ARM CPU, change target argument according to your -# device and follow `Auto-tuning a convolution network for ARM CPU -# `_. +# * Perform autotuning - :ref:`Auto-tuning a convolution network for x86 CPU +# `. +# * To get best inference performance on ARM CPU, change target argument +# according to your device and follow :ref:`Auto-tuning a convolution +# network for ARM CPU `. diff --git a/gallery/how_to/work_with_schedules/schedule_primitives.py b/gallery/how_to/work_with_schedules/schedule_primitives.py index ade79f69707f..65fdeda57c3b 100644 --- a/gallery/how_to/work_with_schedules/schedule_primitives.py +++ b/gallery/how_to/work_with_schedules/schedule_primitives.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """ +.. _schedule_primitives: + Schedule Primitives in TVM ========================== **Author**: `Ziheng Jiang `_ diff --git a/gallery/tutorial/autotvm_relay_x86.py b/gallery/tutorial/autotvm_relay_x86.py index 67faec4505a6..67b832cc226d 100644 --- a/gallery/tutorial/autotvm_relay_x86.py +++ b/gallery/tutorial/autotvm_relay_x86.py @@ -81,10 +81,9 @@ # # .. note:: Working with Other Model Formats # -# TVM supports many popular model formats. A list can be found in the `Compile -# Deep Learning Models -# `_ -# section of the TVM Documentation. +# TVM supports many popular model formats. A list can be found in the +# :ref:`Compile Deep Learning Models ` section of the TVM +# Documentation. model_url = "".join( [ @@ -107,7 +106,7 @@ # TVMC has adopted NumPy's ``.npz`` format for both input and output data. # # As input for this tutorial, we will use the image of a cat, but you can feel -# free to substitute image for any of your choosing. +# free to substitute this image for any of your choosing. # # .. image:: https://s3.amazonaws.com/model-server/inputs/kitten.jpg # :height: 224px @@ -150,9 +149,8 @@ # # Specifying the correct target can have a huge impact on the performance of # the compiled module, as it can take advantage of hardware features -# available on the target. For more information, please refer to `Auto-tuning -# a convolutional network for x86 CPU -# `_. +# available on the target. For more information, please refer to +# :ref:`Auto-tuning a convolutional network for x86 CPU `. # We recommend identifying which CPU you are running, along with optional # features, and set the target appropriately. For example, for some # processors ``target = "llvm -mcpu=skylake"``, or ``target = "llvm @@ -280,6 +278,7 @@ from tvm.autotvm.tuner import XGBTuner from tvm import autotvm +################################################################################ # Set up some basic parameters for the runner. The runner takes compiled code # that is generated with a specific set of parameters and measures the # performance of it. ``number`` specifies the number of different @@ -305,6 +304,7 @@ enable_cpu_cache_flush=True, ) +################################################################################ # Create a simple structure for holding tuning options. We use an XGBoost # algorithim for guiding the search. For a production job, you will want to set # the number of trials to be larger than the value of 10 used here. For CPU we @@ -428,6 +428,7 @@ for rank in ranks[0:5]: print("class='%s' with probability=%f" % (labels[rank], scores[rank])) +################################################################################ # Verifying that the predictions are the same: # # .. code-block:: bash diff --git a/gallery/tutorial/install.py b/gallery/tutorial/install.py index b69b8b493a4f..67ce093b9d7f 100644 --- a/gallery/tutorial/install.py +++ b/gallery/tutorial/install.py @@ -35,8 +35,8 @@ # allow you to enable specific features such as GPU support, microcontroller # support (microTVM), and a debugging runtime, and other features. You will also # want to install from source if you want to actively contribute to the TVM -# project. The full instructions are on the `Install TVM From Source -# `_ page. +# project. The full instructions are on the :ref:`Install TVM From Source +# ` page. ################################################################################ # Installing From Binary Packages diff --git a/gallery/tutorial/intro_topi.py b/gallery/tutorial/intro_topi.py index 8138e4718cd9..dad8c53bf4ae 100644 --- a/gallery/tutorial/intro_topi.py +++ b/gallery/tutorial/intro_topi.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """ +.. _tutorial-topi: + Introduction to TOPI ==================== **Author**: `Ehsan M. Kermani `_ diff --git a/gallery/tutorial/tensor_expr_get_started.py b/gallery/tutorial/tensor_expr_get_started.py index 310d6bdbfee4..e4d947d1c488 100644 --- a/gallery/tutorial/tensor_expr_get_started.py +++ b/gallery/tutorial/tensor_expr_get_started.py @@ -133,7 +133,7 @@ ################################################################################ # Let's run the function, and compare the output to the same computation in -# numpy. The compiled TVM function is exposes a concise C API that can be invoked +# numpy. The compiled TVM function exposes a concise C API that can be invoked # from any language. We begin by creating a device, which is a device (CPU in this # example) that TVM can compile the schedule to. In this case the device is an # LLVM CPU target. We can then initialize the tensors in our device and @@ -258,8 +258,8 @@ def evaluate_addition(func, target, optimization, log): print(tvm.lower(s, [A, B, C], simple_mode=True)) ################################################################################ -# Comparing the Diferent Schedules -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Comparing the Different Schedules +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We can now compare the different schedules baseline = log[0][1] @@ -347,7 +347,7 @@ def evaluate_addition(func, target, optimization, log): fadd = tvm.build(s, [A, B, C], target=tgt_gpu, name="myadd") ################################################################################ - # The compiled TVM function is exposes a concise C API that can be invoked from + # The compiled TVM function exposes a concise C API that can be invoked from # any language. # # We provide a minimal array API in python to aid quick testing and prototyping. @@ -512,7 +512,7 @@ def evaluate_addition(func, target, optimization, log): # before it moves on to the next stage. # # A complete description of these primitives can be found in the -# [Schedule Primitives](https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html) docs page. +# :ref:`Schedule Primitives ` docs page. ################################################################################ # Example 2: Manually Optimizing Matrix Multiplication with TE diff --git a/gallery/tutorial/tvmc_command_line_driver.py b/gallery/tutorial/tvmc_command_line_driver.py index ea3254054ecf..facb978cea67 100644 --- a/gallery/tutorial/tvmc_command_line_driver.py +++ b/gallery/tutorial/tvmc_command_line_driver.py @@ -154,11 +154,9 @@ # Specifying the correct target (option ``--target``) can have a huge # impact on the performance of the compiled module, as it can take # advantage of hardware features available on the target. For more -# information, please refer to `Auto-tuning a convolutional network -# for x86 CPU `_. -# We recommend identifying which CPU you are running, along with optional features, -# and set the target appropriately. -# +# information, please refer to :ref:`Auto-tuning a convolutional network for +# x86 CPU `. We recommend identifying which CPU you are +# running, along with optional features, and set the target appropriately. ################################################################################ # Running the Model from The Compiled Module with TVMC @@ -176,10 +174,10 @@ # data types. For this reason, most models require some pre and # post-processing, to ensure the input is valid and to interpret the output. # TVMC has adopted NumPy's ``.npz`` format for both input and output data. This -# is a well-supported NumPy format to serialize multiple arrays into a file +# is a well-supported NumPy format to serialize multiple arrays into a file. # # As input for this tutorial, we will use the image of a cat, but you can feel -# free to substitute image for any of your choosing. +# free to substitute this image for any of your choosing. # # .. image:: https://s3.amazonaws.com/model-server/inputs/kitten.jpg # :height: 224px @@ -199,8 +197,8 @@ # requirement for the script. # # .. code-block:: python -# :caption: preprocess.py -# :name: preprocess.py +# :caption: preprocess.py +# :name: preprocess.py # # #!python ./preprocess.py # from tvm.contrib.download import download_testdata diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 418d532fdd5f..45a938247cc8 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -30,6 +30,7 @@ #define TVM_DRIVER_DRIVER_API_H_ #include +#include #include #include #include @@ -43,6 +44,34 @@ #include namespace tvm { +using tvm::transform::Pass; + +/*! + * \brief Configures and returns the composite Pass for the fused module (pre split) that contains + * device and host code. + * \param mixed_mod The original mixed module. + * \param target The device Target. + * \return The composite Pass for the fused module. +// */ +TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); + +/*! + * \brief Configures and returns the composite Pass for the device Target after device/host from + * mixed module. + * \param mixed_mod The optimized mixed module. + * \param target The device Target. + * \return The composite Pass for the device module. + */ +TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target); + +/*! + * \brief Configures and returns the composite Pass for the host Target after device/host from mixed + * module. + * \param mixed_mod The optimized mixed module. + * \param target_host The host Target. + * \return The composite Pass for the host module. + */ +TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host); /*! * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) @@ -136,6 +165,7 @@ TVM_DLL runtime::Module build(const Map& input, const Target& * \return The built module that contains code for different processors. */ TVM_DLL runtime::Module build(const Map& input, const Target& target_host); + } // namespace tvm #endif // TVM_DRIVER_DRIVER_API_H_ diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 83b4ddaead43..3652a09e9168 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -76,6 +76,22 @@ struct TopKAttrs : public tvm::AttrsNode { } }; +struct SearchSortedAttrs : public tvm::AttrsNode { + bool right; + DataType dtype; + + TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") { + TVM_ATTR_FIELD(right).set_default(false).describe( + "Controls which index is returned if a value lands exactly on one of sorted values. If " + " false, the index of the first suitable location found is given. If true, return the " + "last such index. If there is no suitable index, return either 0 or N (where N is the " + "size of the innermost dimension)."); + TVM_ATTR_FIELD(dtype) + .set_default(DataType::Int(32)) + .describe("Data type of the output indices."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ALGORITHM_H_ diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index de60deb9cccb..26d2c72c824d 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -686,6 +686,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { Array padding; Array dilation; tvm::String layout; + tvm::String out_layout; bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") { @@ -709,6 +710,13 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); } @@ -721,6 +729,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { Array padding; Array dilation; tvm::String layout; + tvm::String out_layout; bool ceil_mode; bool count_include_pad; @@ -745,6 +754,13 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); TVM_ATTR_FIELD(count_include_pad) @@ -756,6 +772,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { /*! \brief Attributes for global pool operator */ struct GlobalPool2DAttrs : public tvm::AttrsNode { tvm::String layout; + tvm::String out_layout; TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") { TVM_ATTR_FIELD(layout).set_default("NCHW").describe( @@ -763,6 +780,13 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -770,6 +794,7 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode { struct AdaptivePool1DAttrs : public tvm::AttrsNode { Array output_size; std::string layout; + tvm::String out_layout; TVM_DECLARE_ATTRS(AdaptivePool1DAttrs, "relay.attrs.AdaptivePool1DAttrs") { TVM_ATTR_FIELD(output_size).set_default(Array({})).describe("Output width."); @@ -778,6 +803,13 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode { "'N', 'C', 'W' stands for batch, channel, and width" "dimensions respectively. Pooling is applied on the" "'W' dimension."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the" + "'W' dimension."); } }; @@ -785,6 +817,7 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode { struct AdaptivePool2DAttrs : public tvm::AttrsNode { Array output_size; std::string layout; + tvm::String out_layout; TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") { TVM_ATTR_FIELD(output_size) @@ -795,6 +828,13 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -802,6 +842,7 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode { struct AdaptivePool3DAttrs : public tvm::AttrsNode { Array output_size; std::string layout; + tvm::String out_layout; TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") { TVM_ATTR_FIELD(output_size) @@ -812,6 +853,13 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode { "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Pooling is applied on 'D', 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on 'D', 'H' and" + "'W' dimensions."); } }; @@ -822,6 +870,7 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { Array dilation; Array padding; std::string layout; + tvm::String out_layout; bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") { @@ -844,6 +893,12 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." "'N', 'C', 'W' stands for batch, channel, and width" "dimensions respectively. Pooling is applied on the 'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); } @@ -856,6 +911,7 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { Array dilation; Array padding; std::string layout; + tvm::String out_layout; bool ceil_mode; bool count_include_pad; @@ -879,6 +935,12 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { "Dimension ordering of input data. Can be 'NCW', 'NHC', etc." "'N', 'C', 'W' stands for batch, channel, and width" "dimensions respectively. Pooling is applied on the 'W' dimension."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCW', 'NHC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimension."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); TVM_ATTR_FIELD(count_include_pad) @@ -894,6 +956,7 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { Array dilation; Array padding; std::string layout; + tvm::String out_layout; bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") { @@ -917,6 +980,13 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Pooling is applied on the 'D', 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); } @@ -929,6 +999,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { Array dilation; Array padding; std::string layout; + tvm::String out_layout; bool ceil_mode; bool count_include_pad; @@ -953,6 +1024,13 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Pooling is applied on the 'D', 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); TVM_ATTR_FIELD(count_include_pad) diff --git a/include/tvm/runtime/crt/graph_executor.h b/include/tvm/runtime/crt/graph_executor.h index eb68ff56d230..1353d8e06e6b 100644 --- a/include/tvm/runtime/crt/graph_executor.h +++ b/include/tvm/runtime/crt/graph_executor.h @@ -36,7 +36,7 @@ struct TVMModule; /*! \brief operator attributes about tvm op */ typedef struct TVMOpParam { - char func_name[120]; + char func_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; uint32_t num_inputs; uint32_t num_outputs; uint32_t flatten_data; diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 7ee140622bfc..366f4f1deed1 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -198,13 +198,20 @@ class ReportNode : public Object { */ String AsCSV() const; /*! \brief Create a human readable table of profiling metrics. - * \param aggregate Whether or not to join multiple calls to the same op into a single line. - * \param sort Whether or not to sort call frames by descending duration. If - * false and if `aggregate` is false, frames will be sorted by order of - * appearance in the program. Order is undefined if `sort` is false and - * `aggregate` is true. + * + * \param aggregate Whether or not to join multiple calls to the + * same op into a single line. + * + * \param sort Whether or not to sort call frames by descending + * duration. If false and if `aggregate` is false, frames will + * be sorted by order of appearance in the program. Order is + * undefined if `sort` is false and `aggregate` is true. + * + * \param compute_col_sums Whether or not to include sum totals for + * the Count, Duation, and Percent columns. + * */ - String AsTable(bool sort = true, bool aggregate = true) const; + String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const; /*! \brief Convert this report to JSON. * * Output JSON will be of this format: @@ -452,11 +459,23 @@ class CountNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(CountNode, Object); }; -/*! \brief String representation of an array or NDArray shapes +/*! \brief String representation of an array of NDArray shapes * \param shapes Array of NDArrays to get the shapes of. * \return A textual representation of the shapes. For example: `float32[2], int64[1, 2]`. */ String ShapeString(const std::vector& shapes); +/*! \brief String representation of shape encoded as an NDArray + * \param shape NDArray containing the shape. + * \param dtype The dtype of the shape. + * \return A textual representation of the shape. For example: `float32[2]`. + */ +String ShapeString(NDArray shape, DLDataType dtype); +/*! \brief String representation of a shape encoded as a vector + * \param shape Shape as a vector of integers. + * \param dtype The dtype of the shape. + * \return A textual representation of the shape. For example: `float32[2]`. + */ +String ShapeString(const std::vector& shape, DLDataType dtype); } // namespace profiling } // namespace runtime diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 831336b9dbfe..039b1894d7c4 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -198,14 +198,14 @@ class VirtualMachine : public runtime::ModuleNode { * \param reg The register to read from. * \return The read object. */ - inline ObjectRef ReadRegister(RegName reg) const; + ObjectRef ReadRegister(RegName reg) const; /*! * \brief Read a VM register and cast it to int32_t * \param reg The register to read from. * \return The read scalar. */ - inline int64_t LoadScalarInt(RegName reg) const; + int64_t LoadScalarInt(RegName reg) const; /*! * \brief Invoke a VM function. @@ -268,6 +268,22 @@ class VirtualMachine : public runtime::ModuleNode { */ void SetInput(std::string name, TVMArgs args, int offset); + /*! + * \brief Internal hook for profiling the start of an op. + * + * This hook is only called on certain ops that are likely to take a + * significant amount of runtime (normally because they alloc or transfer to + * device). + * + * \param instr Instruction that will be executed after this hook fires + */ + virtual void OpStartHook(Instruction instr); + + /*! + * \brief Internal hook for profiling the end of an op. + */ + virtual void OpStopHook(); + protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs_; diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 8a2bbcbd0121..e802a3088d2d 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_TARGET_KIND_H_ #define TVM_TARGET_TARGET_KIND_H_ +#include #include #include @@ -33,6 +34,33 @@ #include namespace tvm { + +class Target; + +/*! + * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind + * + * Called before the default lowering passes. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. + */ +using FTVMRelayToTIR = transform::Pass; + +/*! + * \brief TIRToRuntime conversion specific to a TargetKind + * + * This function is responsible for scanning an IRModule for appropriate Target-specific functions + and generating a Runtime module representing the compiled output + * + * \param ir_module Unified IRModule + * \param target Target to filter on or retrieve arguments from + * \return Runtime Module containing compiled functions + */ +using FTVMTIRToRuntime = runtime::TypedPackedFunc; + namespace detail { template struct ValueTypeInfoMaker; @@ -201,6 +229,12 @@ class TargetKindRegEntry { * \return The entry names. */ TVM_DLL static Array ListTargetKinds(); + /*! + * \brief Get all supported option names and types for a given Target kind. + * \return Map of option name to type + */ + TVM_DLL static Map ListTargetKindOptions(const TargetKind& kind); + /*! * \brief Register or get a new entry. * \param target_kind_name The name of the TargetKind. diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 23057f7140e4..e4a3d3d1e21b 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -199,9 +199,10 @@ class LinkedParam : public ObjectRef { * def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: * A = T.match_buffer(a, (m, n), "float32") * B = T.match_buffer(b, (m, n), "float32") - * - * with T.block([m, n], "") as [vi, vj]: - * B[vi, vj] = A[vi, vj] + * for i, j in T.grid(m, n): + * with T.block(): + * vi, vj = T.axis.remap("SS", [i, j]) + * B[vi, vj] = A[vi, vj] * \endcode * * Then we can make it specialized with given shapes or buffers. @@ -218,9 +219,10 @@ class LinkedParam : public ObjectRef { * def mem_copy_16_16(a: T.handle, b: T.handle) -> None: * A = T.match_buffer(a, (16, 16), "float32") * B = T.match_buffer(b, (16, 16), "float32") - * - * with T.block([16, 16], "") as [vi, vj]: - * B[vi, vj] = A[vi, vj] + * for i, j in T.grid(16, 16): + * with T.block(): + * vi, vj = T.axis.remap("SS", [i, j]) + * B[vi, vj] = A[vi, vj] * \endcode */ PrimFunc Specialize(PrimFunc func, const Map& param_map); diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5cd860b8e929..4f5772822d9e 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1078,9 +1078,9 @@ class MatchBufferRegion : public ObjectRef { * \note Block's body is parameterized by iter vars. * \code * - * with T.block([extent0, extent1, ...], name) as [v0, v1, ...]: - * T.bind(v0, value0) - * T.bind(v1, value1) + * with T.block(name): + * v0 = T.axis.S(domain, value0) + * v1 = T.axis.R(domain, value1) * ... * T.reads([buffer0[start:end, ...], ...]) * T.writes([buffer1[start:end, ...], ...]) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e94b966bc0fc..e6b0af9773d9 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -388,7 +388,7 @@ TVM_DLL Pass ConvertBlocksToOpaque(); * \code * * for i in range(0, 16): - * with T.block([]): + * with T.block(): * B = T.alloc_buffer(16, 16) * for j in range(0, 16): * B[i, j] = A[i, j] + 1 @@ -404,7 +404,7 @@ TVM_DLL Pass ConvertBlocksToOpaque(); * \code * * for i in range(0, 16): - * with T.block([]): + * with T.block(): * B = T.alloc_buffer(1, 16) * for j in range(0, 16): * B[0, j] = A[i, j] + 1 @@ -463,6 +463,15 @@ TVM_DLL Pass UnifyThreadBinding(); */ TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); +/*! + * \brief This pass is post-scheduling pass to convert all + * Parallel For loops to Serial ones. This is run + * to attain lesser memory and/or executor/backend + * does not support parallel launch of For loops. + * \return The pass. + */ +TVM_DLL Pass ConvertForLoopsToSerial(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 65c5c12a701b..40a0d1ab2f74 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -109,6 +109,12 @@ class Var : public PrimExpr { * \return the new Var copy */ TVM_DLL Var copy_with_suffix(const String& suffix) const; + /*! + * \brief Make a new copy of the variable with specified dtype + * \param dtype The specified dtype + * \return The new variable + */ + TVM_DLL Var copy_with_dtype(DataType dtype) const; /*! * \brief Get pointer to the internal value. diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 8d1a49a4cc5f..3df9caf55d5c 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1233,8 +1233,10 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, } ICHECK_GE(axis, 0); ICHECK_LT(axis, ndim_d); - size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); - ICHECK_GE(indices_dim_i, 1); + if (indices->shape[axis].as()) { + size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); + ICHECK_GE(indices_dim_i, 1); + } ICHECK(indices->dtype.is_int()); Array out_shape; diff --git a/python/gen_requirements.py b/python/gen_requirements.py index fa94d6a64130..e9f3772ee733 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -198,6 +198,7 @@ "sphinx_autodoc_annotation", "sphinx_gallery", "sphinx_rtd_theme", + "types-psutil", ], ), ), @@ -250,7 +251,7 @@ ("sphinx_autodoc_annotation", None), ("sphinx_gallery", None), ("sphinx_rtd_theme", None), - ("synr", "==0.4.1"), + ("synr", "==0.5.0"), ("tensorflow", None), ("tensorflow-estimator", None), ("tflite", None), diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0eacd1a1f667..6f35e021daf8 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -58,7 +58,6 @@ def call_all_topi_funcs(mod, params, target, opt_level=3): opt_level=opt_level, config={ "relay.backend.use_auto_scheduler": True, - "relay.backend.disable_compile_engine_cache": True, }, disabled_pass={"AutoSchedulerLayoutRewrite"}, ): @@ -165,7 +164,8 @@ class TracingMode: """Two modes for tracing""" EXTRACT_TASK = 0 # trace all topi calls to extract tasks - EXTRACT_COMPLEX_TASK_ONLY = 1 # same as EXTRACT_TASK but ignore the task without complex ops + # same as EXTRACT_TASK but ignore the task without complex ops + EXTRACT_COMPLEX_TASK_ONLY = 1 PREPARE_LAYOUT_REWRITE = 2 # trace topi calls to prepare layout rewrite diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index 8d2591dce50b..f73c65fbd1d8 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -31,7 +31,6 @@ import tvm._ffi from tvm.target import Target -from tvm.te import schedule from tvm.driver import build_module @@ -39,13 +38,12 @@ def ana_lower(sch, args, binds=None, simple_mode=True): """Do lower while keeping all axes in IR i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads """ - binds, _ = build_module.get_binds(args, compact=False, binds=binds) sch = sch.normalize() # Phase 0 - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds, True) - func = schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func._move()) + context = tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop": True}) + with context: + mod = build_module.schedule_to_module(sch, args, binds=binds) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) assert simple_mode diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index beb1aa03090d..25d56cf8cf02 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -166,7 +166,7 @@ def __init__( if isinstance(graph, relay.function.Function): node_dict = {} graph = bind_inputs(graph, input_shapes, dtype) - expr2graph(graph, self._target_ops, node_dict, self._node_list) + expr2graph(graph, self._target_ops, node_dict, self._node_list, target) else: raise RuntimeError("Unsupported graph type: %s" % str(type(graph))) diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index f61d34284e01..7299875bf28d 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -17,6 +17,7 @@ # pylint: disable=too-many-locals,too-many-statements,too-many-branches,protected-access """API for graph traversing.""" import threading +import re import tvm from tvm import relay, autotvm @@ -30,7 +31,7 @@ from .._base import OPT_OUT_OP -def expr2graph(expr, target_ops, node_dict, node_list): +def expr2graph(expr, target_ops, node_dict, node_list, tvm_target): """Convert relay expr to graph data structure and fetch workloads of target operators. @@ -50,6 +51,9 @@ def expr2graph(expr, target_ops, node_dict, node_list): Each node will be stored as a dictionary in the format of {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type], "name": str, "workloads": [tuple], "topi_op": [function]} + + tvm_target : tvm.target + The TVM target object. """ # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact # that # autotvm tasks == # ops. But this won't be true after having relay op @@ -58,12 +62,12 @@ def expr2graph(expr, target_ops, node_dict, node_list): env.reset(target_ops) # pylint: disable=not-context-manager with env: - _expr2graph_impl(expr, target_ops, node_dict, node_list) + _expr2graph_impl(expr, target_ops, node_dict, node_list, tvm_target) task_pos = 0 for node_entry in node_list: if node_entry["op"] in target_ops: task_name, args = env.task_collection[task_pos] - task = autotvm.task.create(task_name, args, target="llvm") + task = autotvm.task.create(task_name, args, target=tvm_target) node_entry["workloads"] = [task.workload] node_entry["topi_op"] = [task_name] task_pos += 1 @@ -77,7 +81,18 @@ def _infer_type(node): return entry if isinstance(node, relay.Function) else entry.body -def _expr2graph_impl(expr, target_ops, node_dict, node_list): +def _replace_device_with_tracing(target): + """This is to replace -device=XXX with -device=tracing in the tvm_target string. + It is a stand-along function for testability. + We need to have device=tracing in order to fetch the workloads, it is not used + for anything beyond that so it is safe to override the device here only.""" + target = str(target) + if "-device" in target: + return re.sub("-device=[^\\-$]+", "-device=tracing ", target).strip(" ") + return target + " -device=tracing" + + +def _expr2graph_impl(expr, target_ops, node_dict, node_list, tvm_target): """Implementation to convert relay expr to graph data structure""" def _traverse_expr(node): @@ -127,9 +142,10 @@ def _traverse_expr(node): params.append(free_var) call = relay.Call(node.op, params, node.attrs) mod = tvm.IRModule.from_expr(relay.Function(params, call)) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() + tracing_target = _replace_device_with_tracing(tvm_target) build_thread = threading.Thread( - target=relay.build, args=(mod, "llvm -device=tracing", None, None) + target=relay.build, args=(mod, tracing_target, None, None) ) build_thread.start() build_thread.join() @@ -139,7 +155,7 @@ def _traverse_expr(node): elif isinstance(node, Function): # Ignore root node since it equals to input function expression if node != expr: - _expr2graph_impl(node, target_ops, node_dict, node_list) + _expr2graph_impl(node, target_ops, node_dict, node_list, tvm_target) return elif isinstance(node, TupleGetItem): in_node_idx = node_dict[node.tuple_value] diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 714dd540d3ab..4716116a1b83 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -127,12 +127,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No assert isinstance( mod, tvm.IRModule ), "only support relay Module or Function to be tuned" - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() # wrap build call in thread to avoid multiprocessing problems build_thread = threading.Thread(target=_lower, args=(mod, target, param)) build_thread.start() build_thread.join() - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() # Clear the warning message cache in FallbackContext if isinstance(DispatchContext.current, FallbackContext): DispatchContext.current.memory = {} diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index 36c03891d210..37b9fed8eb91 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -16,6 +16,7 @@ # under the License. """Pipeline executor that executes a series of modules in a pipeline fashion.""" import json +import os import tvm._ffi from tvm import relay from tvm.relay.transform import InferType @@ -47,13 +48,13 @@ def build(pipe_configs): ret: PipelineExecutorFactoryModule Common interface for pipeline executor factory modules. """ - mods = {} + libs = {} mod_n_configs = pipe_configs.get_config() config_len = len(mod_n_configs) string_config = [{} for _ in range(config_len)] for ir_mod, mod_config in mod_n_configs.items(): mconf = mod_config["pipeline"].copy() - mod_idx = mconf["mod_idx"] - 1 + mod_idx = mconf["mod_idx"] dev = mod_config["dev"] target = mod_config["target"] build_func = relay.build @@ -61,7 +62,7 @@ def build(pipe_configs): if "build" in mod_config and mod_config["build"]: build_func = mod_config["build"] - mod = build_func( + lib = build_func( ir_mod, target, params=mod_config["params"], @@ -72,9 +73,9 @@ def build(pipe_configs): mconf["dev"] = "{},{}".format(dev.device_type, dev.device_id) # Create a pipeline configuration. string_config[mod_idx] = mconf - mods[mod] = {"dev": dev} + libs[mod_idx] = {"lib": lib, "dev": dev} - return PipelineExecutorFactoryModule(mods, string_config) + return PipelineExecutorFactoryModule(libs, string_config) class PipelineModule(object): @@ -82,12 +83,59 @@ class PipelineModule(object): Parameters ---------- - module : PipelineExecutorFactoryModule - Common interface for pipeline executor factory modules. + module : Union[PipelineExecutorFactoryModule, Module] + Common interface for pipeline executor factory modules or Module. """ def __init__(self, module): - self.module = module.module + if isinstance(module, PipelineExecutorFactoryModule): + self.module = module.module + else: + self.module = module + # Get the packed functions from the pipeline executor. + self._get_num_outputs = self.module["get_num_outputs"] + + @property + def num_outputs(self): + """Get the number of outputs. + Returns + ------- + count : int + The number of outputs. + """ + return self._get_num_outputs() + + @staticmethod + def load_library(config_file_name): + """Import files to create a pipeline executor. + + Parameters + ---------- + config_file_name : str + Path and name of the configuration file, the configuration file contains the + disk path of the parameter file, library file, and JSON file. + """ + with open(config_file_name, "r") as file_handle: + config = file_handle.read() + config = json.loads(config) + if "load_config" not in config or "pipeline_config" not in config: + raise RuntimeError( + '"load_config" or "pipeline_config" is missing in %s' % config_file_name + ) + + # The config file used to load library, prameters, and JSON files. + with open(config["load_config"], "r") as file_handle: + load_config = file_handle.read() + + # The config file used to load pipeline compute config. + with open(config["pipeline_config"], "r") as file_handle: + pipeline_config = file_handle.read() + + # Load a PipelineExecutor from the disk files. + load_library = tvm._ffi.get_global_func("tvm.pipeline_executor.load", allow_missing=False) + module = load_library(load_config, pipeline_config) + + return PipelineModule(module) class PipelineConfig(object): @@ -139,13 +187,14 @@ def get_owner_idx(self): if isinstance(self.io_owner, PipelineConfig.ModuleWrapper): return self.io_owner.idx - return 0 + return -1 - def is_global_interface(self): - """The global interface is the interface visible to the caller which use a pipeline - executor, the global input interface is responsible for passing parameters to the - internal module interface, and the global output interface is responsible for - outputting the results computed by the pipeline executor to a caller. + def is_pipeline_executor_interface(self): + """The pipeline interface is used to interact with the caller. There are two types + of interfaces, one is 'input' another is 'output'. The pipeline input interface + is responsible for passing parameters to the internal module interface, and the + pipeline output interface is responsible for outputting the results computed by + the pipeline executor to the caller. """ return not isinstance(self.io_owner, PipelineConfig.ModuleWrapper) @@ -182,9 +231,9 @@ def check_dag_acyclic(self, start, inputs): def connect(self, binding): """Connect the current interface to the destination interface. - Correct connections are as follows: 1. global input connected to module input, - 2. module output connected to global output, 3. module output connected to - module input. + Correct connections are as follows: 1. the pipeline input connected to a module input, + 2. the module output connected to a pipeline output, 3. the module output connected to + a module input. Parameters ---------- @@ -196,31 +245,31 @@ def connect(self, binding): if self.io_owner == binding.io_owner: raise RuntimeError(f"Can not bind itself.") - if not self.is_global_interface() and self.io_type == "input": + if not self.is_pipeline_executor_interface() and self.io_type == "input": raise RuntimeError(f"Module can only bind from output interface!") if ( - not self.is_global_interface() - and not binding.is_global_interface() + not self.is_pipeline_executor_interface() + and not binding.is_pipeline_executor_interface() and binding.io_type == "output" ): raise RuntimeError(f"Can not bind module output with another module output!") if ( - not self.is_global_interface() - and binding.is_global_interface() + not self.is_pipeline_executor_interface() + and binding.is_pipeline_executor_interface() and binding.io_type == "input" ): - raise RuntimeError(f"Can not bind module output with global input!") + raise RuntimeError(f"Can not bind module output with pipeline input!") - if self.is_global_interface() and self.io_type == "output": + if self.is_pipeline_executor_interface() and self.io_type == "output": raise RuntimeError(f"Global output can not be used as binding start point.") - if self.is_global_interface() and binding.io_type != "input": + if self.is_pipeline_executor_interface() and binding.io_type != "input": raise RuntimeError(f"Global input can only bind with module input.") self.bindings.append(binding) - if not self.is_global_interface(): + if not self.is_pipeline_executor_interface(): # Check whether the data types of the source and destination are the same. if ( isinstance(binding.io_owner, PipelineConfig.ModuleWrapper) @@ -431,13 +480,16 @@ def get_config(self): for dep in binding.bindings: dep_item = {} _, dname = dep.get_name() - dep_item["mod_idx"] = dep.get_owner_idx() - dep_item["input_name"] = dname + if dep.is_pipeline_executor_interface(): + dep_item["global_output_index"] = int(dname) + else: + dep_item["mod_idx"] = dep.get_owner_idx() + dep_item["input_name"] = dname dep_conf.append(dep_item) # The value of ouput_idx start from 0. output["output_idx"] = int(binding.name) - output["dependent"] = dep_conf + output["dependencies"] = dep_conf output_conf.append(output) mconf["mod_idx"] = module.idx @@ -472,7 +524,7 @@ def dag_topology_sort(self): mlist += temp_list for mod, i in zip(mlist, range(len(mlist))): - self.mod_wrapper[mod].set_idx_name(i + 1) + self.mod_wrapper[mod].set_idx_name(i) def get_mod_idx(self, mod): # Return the module index. @@ -502,16 +554,13 @@ class PipelineExecutorFactoryModule(object): """ def __init__(self, pipeline_mods, mods_config): - mods, config = self.graph_executor_create(pipeline_mods, mods_config) - assert ( - pipeline_executor_enabled() - ), "Pipeline executor is not enabled. Please \ - re-build TVM with USE_PIPELINE_EXECUTOR=ON" - pipeline_create = tvm._ffi.get_global_func( + self.pipeline_mods = pipeline_mods + self.mods_config = mods_config + graph_executors, config = self.graph_executor_create(pipeline_mods, mods_config) + self.pipeline_create = tvm._ffi.get_global_func( "tvm.pipeline_executor.create", allow_missing=False ) - assert pipeline_create - self.module = pipeline_create(mods, config) + self.module = self.pipeline_create(graph_executors, config) def graph_executor_create(self, pipeline_mods, mod_config): """Create graph_executor list and return configuration as a json string. @@ -532,12 +581,70 @@ def graph_executor_create(self, pipeline_mods, mod_config): mod_config : str The Modudle configuration. """ + # Should store modules in the list named 'mods' in index order. + mods = [None for _ in range(len(pipeline_mods))] + for lib_index in pipeline_mods: + pipeline_lib = pipeline_mods[lib_index]["lib"] + dev = pipeline_mods[lib_index]["dev"] + lib = graph_executor.GraphModule(pipeline_lib["default"](dev)) + # Return a module list sorted by lib_index. + mods[lib_index] = lib.module + + return mods, json.dumps(mod_config) + + def export_library(self, directory_path): + """Export the pipeline executor into disk files. - mods = [] - for pipeline_mod in pipeline_mods: - mod = graph_executor.GraphModule( - pipeline_mod["default"](pipeline_mods[pipeline_mod]["dev"]) + Parameters + ---------- + directory_path : str + Export the files to this directory. + """ + if not self.pipeline_mods: + raise RuntimeError(f"The pipeline executor has not been initialized.") + + # Check if the directory_path exists. + if not os.path.exists(directory_path): + raise RuntimeError(f"The directory {directory_path} does not exist.") + # Create an load configuration. + load_config_file_name = "{}/load_config".format(directory_path) + pipeline_config_file_name = "{}/pipeline_config".format(directory_path) + config = {} + config["load_config"] = load_config_file_name + config["pipeline_config"] = pipeline_config_file_name + load_config = [] + # Export the library, JSON, and parameter into files, then export these files path + # into a configuration file. + for lib_index in self.pipeline_mods: + mconfig = {} + mconfig["mod_idx"] = lib_index + mconfig["lib_name"] = "{}/lib{}.so".format(directory_path, lib_index) + mconfig["json_name"] = "{}/json{}".format(directory_path, lib_index) + mconfig["params_name"] = "{}/params{}".format(directory_path, lib_index) + mconfig["dev"] = "{},{}".format( + self.pipeline_mods[lib_index]["dev"].device_type, + self.pipeline_mods[lib_index]["dev"].device_id, ) - mods.append(mod.module) - return mods, json.dumps(mod_config) + # Get the graph, lib, and parameters from GraphExecutorFactoryModule. + graph, lib, params = self.pipeline_mods[lib_index]["lib"] + # Export the lib, graph, and parameters to disk. + lib.export_library(mconfig["lib_name"]) + with open(mconfig["json_name"], "w") as file_handle: + file_handle.write(graph) + with open(mconfig["params_name"], "wb") as file_handle: + file_handle.write(relay.save_param_dict(params)) + + load_config.append(mconfig) + + with open(load_config_file_name, "w") as file_handle: + json.dump(load_config, file_handle) + + with open(pipeline_config_file_name, "w") as file_handle: + json.dump(self.mods_config, file_handle) + + config_file_name = "{}/config".format(directory_path) + with open(config_file_name, "w") as file_handle: + json.dump(config, file_handle) + + return config_file_name diff --git a/python/tvm/contrib/xcode.py b/python/tvm/contrib/xcode.py index c44a2fe4a136..6d5e10f611db 100644 --- a/python/tvm/contrib/xcode.py +++ b/python/tvm/contrib/xcode.py @@ -45,7 +45,23 @@ def xcrun(cmd): return out.strip() -def create_dylib(output, objects, arch, sdk="macosx"): +def __get_min_os_version(sdk): + if sdk in ("macosx", "iphonesimulator"): + return None + if sdk == "iphoneos": + return "13.0" + raise RuntimeError("Unsupported sdk: %s" % sdk) + + +def __get_min_os_version_cmd(sdk, min_os_version): + if min_os_version is None: + min_os_version = __get_min_os_version(sdk) + if min_os_version is not None: + return "-mios-version-min=" + min_os_version + return "" + + +def create_dylib(output, objects, arch, sdk="macosx", min_os_version=None): """Create dynamic library. Parameters @@ -71,6 +87,7 @@ def create_dylib(output, objects, arch, sdk="macosx"): cmd += ["-dynamiclib"] cmd += ["-arch", arch] cmd += ["-isysroot", sdk_path] + cmd += [__get_min_os_version_cmd(sdk, min_os_version)] cmd += ["-o", output] if isinstance(objects, str): cmd += [objects] @@ -90,7 +107,7 @@ def create_dylib(output, objects, arch, sdk="macosx"): create_dylib.output_format = "dylib" -def compile_metal(code, path_target=None, sdk="macosx"): +def compile_metal(code, path_target=None, sdk="macosx", min_os_version=None): """Compile metal with CLI tool from env. Parameters @@ -123,7 +140,14 @@ def compile_metal(code, path_target=None, sdk="macosx"): # # xcrun -sdk macosx metal -c MyLibrary.metal -o MyLibrary.air # xcrun -sdk macosx metallib MyLibrary.air -o MyLibrary.metallib - cmd1 = ["xcrun", "-sdk", sdk, "metal", "-O3"] + min_target = __get_min_os_version_cmd(sdk, min_os_version) + if sdk == "macosx": + language_version = "-std=macos-metal2.3" + elif sdk in ("iphoneos", "iphonesimulator"): + language_version = "-std=ios-metal2.3" + else: + raise RuntimeError("Unsupported sdk: %s" % sdk) + cmd1 = ["xcrun", "-sdk", sdk, "metal", language_version, min_target, "-O3"] cmd1 += ["-c", temp_code, "-o", temp_ir] cmd2 = ["xcrun", "-sdk", sdk, "metallib"] cmd2 += [temp_ir, "-o", file_target] diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a7ebc00c315f..5ec44c6d6ed1 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -16,27 +16,23 @@ # under the License. # pylint: disable=invalid-name -"""The build utils in python. -""" +"""The build utils in python.""" from typing import Union, Optional, List, Mapping -import warnings import tvm.tir from tvm.runtime import Module from tvm.runtime import ndarray from tvm.ir import container -from tvm.ir import CallingConv from tvm.tir import PrimFunc from tvm.ir.module import IRModule -from tvm.ir.transform import PassContext -from tvm.target import codegen from tvm.te import tensor from tvm.te import schedule from tvm.target import Target from tvm.tir.buffer import Buffer from tvm.tir.expr import Var +from tvm.driver import _ffi_api as _driver_ffi from . import _ffi_api as ffi @@ -71,6 +67,11 @@ def schedule_to_module( binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, ) -> IRModule: """According to the given schedule, form a function. + + This is a low-level function intended for testing purposes, and + does not apply any optimization passes. In general, `tvm.lower` + and `tvm.build` should be used instead. + Parameters ---------- sch : tvm.te.schedule.Schedule @@ -104,8 +105,8 @@ def lower( args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] The argument lists to the function for TE schedule. - It should be None if we want to lower TensorIR. + It should be None if we want to lower TensorIR. name : str The name of the result function. @@ -132,98 +133,6 @@ def lower( raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) -def _build_for_device(input_mod, target, target_host): - """Build the lowered functions for a device with the given compilation - target. - - Parameters - ---------- - input_mod : IRModule - The schedule to be built. - - target : str or :any:`tvm.target.Target` - The target and option of the compilation. - - target_host : str or :any:`tvm.target.Target` - The host compilation target. - - Returns - ------- - fhost : IRModule - The host IRModule. - - mdev : tvm.module - A module that contains device code. - """ - target, target_host = Target.check_and_update_host_consist(target, target_host) - device_type = ndarray.device(target.kind.name, 0).device_type - - mod_mixed = input_mod - mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - - opt_mixed = [ - tvm.tir.transform.VerifyMemory(), - tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), - ] - if len(mod_mixed.functions) == 1: - opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] - - if PassContext.current().config.get("tir.detect_global_barrier", False): - opt_mixed += [tvm.tir.transform.ThreadSync("global")] - opt_mixed += [ - tvm.tir.transform.ThreadSync("shared"), - tvm.tir.transform.ThreadSync("warp"), - tvm.tir.transform.InferFragment(), - tvm.tir.transform.LowerThreadAllreduce(), - tvm.tir.transform.MakePackedAPI(), - tvm.tir.transform.SplitHostDevice(), - ] - mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed) - - # device optimizations - opt_device = tvm.transform.Sequential( - [ - tvm.tir.transform.Filter( - lambda f: "calling_conv" in f.attrs - and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH - ), - tvm.tir.transform.LowerWarpMemory(), - tvm.tir.transform.Simplify(), - tvm.tir.transform.LowerDeviceStorageAccessInfo(), - tvm.tir.transform.LowerCustomDatatypes(), - tvm.tir.transform.LowerIntrin(), - ] - ) - mod_dev = opt_device(mod_mixed) - - # host optimizations - opt_host = tvm.transform.Sequential( - [ - tvm.tir.transform.Filter( - lambda f: "calling_conv" not in f.attrs - or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH - ), - tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)), - tvm.tir.transform.LowerTVMBuiltin(), - tvm.tir.transform.LowerDeviceStorageAccessInfo(), - tvm.tir.transform.LowerCustomDatatypes(), - tvm.tir.transform.LowerIntrin(), - tvm.tir.transform.CombineContextCall(), - ] - ) - mod_host = opt_host(mod_mixed) - - if device_type == ndarray.cpu(0).device_type and target_host == target: - assert len(mod_dev.functions) == 0 - if "gpu" in target.keys and len(mod_dev.functions) == 0: - warnings.warn( - "Specified target %s, but cannot find device code, did you do " "bind?" % target - ) - - rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None - return mod_host, rt_mod_dev - - def build( inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -237,7 +146,8 @@ def build( Parameters ---------- - inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] + inputs : Union[tvm.te.schedule.Schedule, + tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] The input to be built args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] @@ -253,7 +163,7 @@ def build( setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. name : Optional[str] The name of result function. @@ -350,21 +260,11 @@ def build( target_input_mod, target_host ) - mod_host_all = tvm.IRModule({}) - - device_modules = [] - for tar, input_mod in target_input_mod.items(): - mod_host, mdev = _build_for_device(input_mod, tar, target_host) - mod_host_all.update(mod_host) - device_modules.append(mdev) - - # Generate a unified host module. - rt_mod_host = codegen.build_module(mod_host_all, target_host) + rt_mod_host = _driver_ffi.preprocess_module(target_input_mod, target_host) - # Import all modules. - for mdev in device_modules: - if mdev: - rt_mod_host.import_module(mdev) + target_input_mod, target_host = Target.check_and_update_host_consist( + target_input_mod, target_host + ) if not isinstance(target_host, Target): target_host = Target(target_host) diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index dab855abfb11..92d13a99acd5 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -21,7 +21,7 @@ import logging import time from copy import deepcopy -from typing import Optional, Dict, List, Union +from typing import Any, Optional, Dict, List, Union from urllib.parse import urlparse @@ -38,6 +38,7 @@ from .common import TVMCException from .main import register_parser from .model import TVMCModel +from .target import generate_target_args, reconstruct_target_args # pylint: disable=invalid-name @@ -106,16 +107,14 @@ def add_tune_parser(subparsers): help="hostname (required) and port (optional, defaults to 9090) of the RPC tracker, " "e.g. '192.168.0.100:9999'", ) - parser.add_argument( - "--target", - help="compilation target as plain string, inline JSON or path to a JSON file", - required=True, - ) + + generate_target_args(parser) parser.add_argument( "--target-host", help="the host compilation target, defaults to 'llvm'", default="llvm", ) + parser.add_argument("--timeout", type=int, default=10, help="compilation timeout, in seconds") parser.add_argument( "--trials", @@ -286,6 +285,7 @@ def drive_tune(args): hardware_params=hardware_params, include_simple_tasks=args.include_simple_tasks, log_estimated_latency=args.log_estimated_latency, + additional_target_options=reconstruct_target_args(args), ) @@ -311,6 +311,7 @@ def tune_model( hardware_params: Optional[HardwareParams] = None, include_simple_tasks: bool = False, log_estimated_latency: bool = False, + additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, ): """Use tuning to automatically optimize the functions in a model. @@ -367,13 +368,15 @@ def tune_model( the autoscheduler. log_estimated_latency : bool, optional If using the autoscheduler, write the estimated latency at each step of tuning to file. + additional_target_options: Optional[Dict[str, Dict[str, Any]]] + Additional target options in a dictionary to combine with initial Target arguments Returns ------- tuning_records : str The path to the produced tuning log file. """ - target, extra_targets = common.target_from_cli(target) + target, extra_targets = common.target_from_cli(target, additional_target_options) target, target_host = Target.check_and_update_host_consist(target, target_host) # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source # model is fixed. For now, creating a clone avoids the issue. diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 9ef2f6f1fbfa..1ee24cf69d44 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -80,7 +80,7 @@ def convert_graph_layout(mod, desired_layout): ) -def validate_targets(parse_targets): +def validate_targets(parse_targets, additional_target_options=None): """ Apply a series of validations in the targets provided via CLI. """ @@ -104,6 +104,15 @@ def validate_targets(parse_targets): f"Found: {verbose_tvm_targets}." ) + if additional_target_options is not None: + for target_name in additional_target_options: + if not any([target for target in parse_targets if target["name"] == target_name]): + first_option = list(additional_target_options[target_name].keys())[0] + raise TVMCException( + f"Passed --target-{target_name}-{first_option}" + f" but did not specify {target_name} target" + ) + def tokenize_target(target): """ @@ -261,7 +270,21 @@ def is_inline_json(target): return False -def target_from_cli(target): +def _combine_target_options(target, additional_target_options=None): + if additional_target_options is None: + return target + if target["name"] in additional_target_options: + target["opts"].update(additional_target_options[target["name"]]) + return target + + +def _recombobulate_target(target): + name = target["name"] + opts = " ".join([f"-{key}={value}" for key, value in target["opts"].items()]) + return f"{name} {opts}" + + +def target_from_cli(target, additional_target_options=None): """ Create a tvm.target.Target instance from a command line interface (CLI) string. @@ -272,6 +295,10 @@ def target_from_cli(target): compilation target as plain string, inline JSON or path to a JSON file + additional_target_options: Optional[Dict[str, Dict[str,str]]] + dictionary of additional target options to be + combined with parsed targets + Returns ------- tvm.target.Target @@ -298,18 +325,22 @@ def target_from_cli(target): except ValueError as ex: raise TVMCException(f"Error parsing target string '{target}'.\nThe error was: {ex}") - validate_targets(parsed_targets) - tvm_targets = [t for t in parsed_targets if t["is_tvm_target"]] + validate_targets(parsed_targets, additional_target_options) + tvm_targets = [ + _combine_target_options(t, additional_target_options) + for t in parsed_targets + if t["is_tvm_target"] + ] # Validated target strings have 1 or 2 tvm targets, otherwise # `validate_targets` above will fail. if len(tvm_targets) == 1: - target = tvm_targets[0]["raw"] + target = _recombobulate_target(tvm_targets[0]) target_host = None else: assert len(tvm_targets) == 2 - target = tvm_targets[0]["raw"] - target_host = tvm_targets[1]["raw"] + target = _recombobulate_target(tvm_targets[0]) + target_host = _recombobulate_target(tvm_targets[1]) extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]] @@ -387,7 +418,7 @@ def parse_shape_string(inputs_string): ---------- inputs_string: str A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that - indicates the desired shape for specific model inputs. Colons and forward slashes + indicates the desired shape for specific model inputs. Colons, forward slashes and dots within input_names are supported. Spaces are supported inside of dimension arrays. Returns @@ -401,7 +432,8 @@ def parse_shape_string(inputs_string): # * Spaces inside arrays # * forward slashes inside names (but not at the beginning or end) # * colons inside names (but not at the beginning or end) - pattern = r"(?:\w+\/)?[:\w]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + # * dots inside names + pattern = r"(?:\w+\/)?[:\w.]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" input_mappings = re.findall(pattern, inputs_string) if not input_mappings: raise argparse.ArgumentTypeError( diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 9eb85a4934cb..7623a141c27a 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -19,7 +19,7 @@ """ import logging import os.path -from typing import Optional, Dict, List, Union, Callable +from typing import Any, Optional, Dict, List, Union, Callable from pathlib import Path import tvm @@ -30,6 +30,7 @@ from . import common, composite_target, frontends from .model import TVMCModel, TVMCPackage from .main import register_parser +from .target import generate_target_args, reconstruct_target_args # pylint: disable=invalid-name @@ -91,11 +92,7 @@ def add_compile_parser(subparsers): "times, each one to set one configuration value, " "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.", ) - parser.add_argument( - "--target", - help="compilation targets as comma separated string, inline JSON or path to a JSON file.", - required=True, - ) + generate_target_args(parser) parser.add_argument( "--tuning-records", metavar="PATH", @@ -154,6 +151,7 @@ def drive_compile(args): desired_layout=args.desired_layout, disabled_pass=args.disabled_pass, pass_context_configs=args.pass_config, + additional_target_options=reconstruct_target_args(args), ) return 0 @@ -172,6 +170,7 @@ def compile_model( desired_layout: Optional[str] = None, disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, + additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, ): """Compile a model from a supported framework into a TVM module. @@ -215,6 +214,8 @@ def compile_model( pass_context_configs: list[str], optional List of strings containing a set of configurations to be passed to the PassContext. + additional_target_options: Optional[Dict[str, Dict[str, Any]]] + Additional target options in a dictionary to combine with initial Target arguments Returns @@ -230,7 +231,7 @@ def compile_model( if desired_layout: mod = common.convert_graph_layout(mod, desired_layout) - tvm_target, extra_targets = common.target_from_cli(target) + tvm_target, extra_targets = common.target_from_cli(target, additional_target_options) tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host) for codegen_from_cli in extra_targets: diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py new file mode 100644 index 000000000000..7a078b8be087 --- /dev/null +++ b/python/tvm/driver/tvmc/target.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This file contains functions for processing target inputs for the TVMC CLI +""" + +from tvm.target import Target + +# We can't tell the type inside an Array but all current options are strings so +# it can default to that. Bool is used alongside Integer but aren't distinguished +# between as both are represented by IntImm +INTERNAL_TO_NATIVE_TYPE = {"runtime.String": str, "IntImm": int, "Array": str} +INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} + + +def _generate_target_kind_args(parser, kind): + target_group = parser.add_argument_group(f"target {kind.name}") + for target_option, target_type in kind.options.items(): + if target_type in INTERNAL_TO_NATIVE_TYPE: + target_group.add_argument( + f"--target-{kind.name}-{target_option}", + type=INTERNAL_TO_NATIVE_TYPE[target_type], + help=f"target {kind.name} {target_option}{INTERNAL_TO_HELP[target_type]}", + ) + + +def generate_target_args(parser): + """Walks through the TargetKind registry and generates arguments for each Target's options""" + parser.add_argument( + "--target", + help="compilation target as plain string, inline JSON or path to a JSON file", + required=True, + ) + target_kinds = Target.list_kinds() + for target_kind in target_kinds: + target = Target(target_kind) + _generate_target_kind_args(parser, target.kind) + + +def _reconstruct_target_kind_args(args, kind): + kind_options = {} + for target_option, target_type in kind.options.items(): + if target_type in INTERNAL_TO_NATIVE_TYPE: + var_name = f"target_{kind.name}_{target_option.replace('-', '_')}" + option_value = getattr(args, var_name) + if option_value is not None: + kind_options[target_option] = getattr(args, var_name) + return kind_options + + +def reconstruct_target_args(args): + """Reconstructs the target options from the arguments""" + target_kinds = Target.list_kinds() + reconstructed = {} + for target_kind in target_kinds: + target = Target(target_kind) + kind_options = _reconstruct_target_kind_args(args, target.kind) + if kind_options: + reconstructed[target.kind.name] = kind_options + return reconstructed diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index e3c78d741b20..caa266f97eb3 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -239,13 +239,11 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: result: List[float] = future.result() error_message: str = None except TimeoutError as exception: - result: List[float] = None - error_message: str = ( - f"LocalRunner: Timeout, killed after {self.timeout_sec} seconds\n" - ) + result = None + error_message = f"LocalRunner: Timeout, killed after {self.timeout_sec} seconds\n" except Exception as exception: # pylint: disable=broad-except - result: List[float] = None - error_message: str = "LocalRunner: An exception occurred\n" + str(exception) + result = None + error_message = "LocalRunner: An exception occurred\n" + str(exception) local_future = LocalRunnerFuture(res=result, error_message=error_message) results.append(local_future) return results diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 3afdff6de77e..15f8295f2524 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -41,7 +41,7 @@ class ReplayTrace(SearchStrategy): def __init__(self, num_trials_per_iter: int, num_trials_total: int): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.ReplayTrace, # pylint: disable=no-member + _ffi_api.ReplayTrace, # type: ignore # pylint: disable=no-member num_trials_per_iter, num_trials_total, ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 72713155c41d..d270ea61f6dc 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -56,7 +56,7 @@ def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: The argument information. """ self.__init_handle_by_constructor__( - _ffi_api.MeasureCandidate, # pylint: disable=no-member + _ffi_api.MeasureCandidate, # type: ignore # pylint: disable=no-member sch, args_info, ) @@ -80,7 +80,7 @@ def initialize_with_tune_context( tune_context : TuneContext The tuning context for initialization. """ - _ffi_api.SearchStrategyInitializeWithTuneContext( # pylint: disable=no-member + _ffi_api.SearchStrategyInitializeWithTuneContext( # type: ignore # pylint: disable=no-member self, tune_context ) @@ -92,11 +92,11 @@ def pre_tuning(self, design_spaces: List[Schedule]) -> None: design_spaces : List[Schedule] The design spaces for pre-tuning. """ - _ffi_api.SearchStrategyPreTuning(self, design_spaces) # pylint: disable=no-member + _ffi_api.SearchStrategyPreTuning(self, design_spaces) # type: ignore # pylint: disable=no-member def post_tuning(self) -> None: """Post-tuning for the search strategy.""" - _ffi_api.SearchStrategyPostTuning(self) # pylint: disable=no-member + _ffi_api.SearchStrategyPostTuning(self) # type: ignore # pylint: disable=no-member def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: """Generate measure candidates from design spaces for measurement. @@ -106,7 +106,7 @@ def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: measure_candidates : Optional[List[IRModule]] The measure candidates generated, None if finished. """ - return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # pylint: disable=no-member + return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member def notify_runner_results(self, results: List[RunnerResult]) -> None: """Update the search strategy with profiling results. @@ -116,7 +116,7 @@ def notify_runner_results(self, results: List[RunnerResult]) -> None: results : List[RunnerResult] The profiling results from the runner. """ - _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # pylint: disable=no-member + _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # type: ignore # pylint: disable=no-member @register_object("meta_schedule.PySearchStrategy") @@ -142,7 +142,7 @@ def f_notify_runner_results(results: List["RunnerResult"]) -> None: self.notify_runner_results(results) self.__init_handle_by_constructor__( - _ffi_api.SearchStrategyPySearchStrategy, # pylint: disable=no-member + _ffi_api.SearchStrategyPySearchStrategy, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_pre_tuning, f_post_tuning, diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index b8dcfd9e7a2d..f1e21ad3ddfe 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -27,7 +27,7 @@ class TaskScheduler(Object): def tune(self) -> None: """Auto-tuning.""" - _ffi_api.TaskSchedulerTune(self) # pylint: disable=no-member + _ffi_api.TaskSchedulerTune(self) # type: ignore # pylint: disable=no-member def _set_task_stopped(self, task_id: int) -> None: """Set specific task to be stopped. @@ -37,7 +37,7 @@ def _set_task_stopped(self, task_id: int) -> None: task_id : int The task id to be stopped. """ - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member + _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # type: ignore # pylint: disable=no-member def _is_task_running(self, task_id: int) -> bool: """Check whether the task is running. @@ -52,7 +52,7 @@ def _is_task_running(self, task_id: int) -> bool: bool Whether the task is running. """ - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member + return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member def _join_running_task(self, task_id: int) -> None: """Wait until the task is finished. @@ -62,7 +62,7 @@ def _join_running_task(self, task_id: int) -> None: task_id : int The task id to be joined. """ - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member + _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member def _next_task_id(self) -> int: """Fetch the next task id. @@ -72,7 +72,7 @@ def _next_task_id(self) -> int: int The next task id. """ - return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member + return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member @register_object("meta_schedule.PyTaskScheduler") @@ -98,7 +98,7 @@ def f_next_task_id() -> int: return self._next_task_id() self.__init_handle_by_constructor__( - _ffi_api.TaskSchedulerPyTaskScheduler, # pylint: disable=no-member + _ffi_api.TaskSchedulerPyTaskScheduler, # type: ignore # pylint: disable=no-member f_tune, f_set_task_stopped, f_is_task_running, @@ -110,13 +110,13 @@ def tune(self) -> None: raise NotImplementedError() def _set_task_stopped(self, task_id: int) -> None: - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member + _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # type: ignore # pylint: disable=no-member def _is_task_running(self, task_id: int) -> bool: - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member + return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member def _join_running_task(self, task_id: int) -> None: - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member + _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member def _next_task_id(self) -> int: - return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member + return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index bf2ef17fb308..c79137d55dda 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -20,7 +20,7 @@ import shutil from typing import Any, Callable, List, Optional, Union -import psutil +import psutil # type: ignore import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError diff --git a/python/tvm/relay/backend/__init__.py b/python/tvm/relay/backend/__init__.py index 4fc2b63748db..d76459236515 100644 --- a/python/tvm/relay/backend/__init__.py +++ b/python/tvm/relay/backend/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """Backend codegen modules for relay.""" -from . import compile_engine +from . import te_compiler diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index fd58da803623..b970aec62c6f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -208,6 +208,99 @@ def __call__(self, *args, **kwargs): pass +class EthosuDepthwiseConv2DRewriter(DFPatternCallback): + """Convert ethosu.qnn_depthwise_conv2d composite functions to ethosu_depthwise_conv2d + operators""" + + def __init__(self): + super().__init__(require_type=True) + self.pattern = ( + wildcard().has_attr( + {"Composite": ethosu_patterns.QnnDepthwiseConv2DParams.composite_name} + ) + )(wildcard()) + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = ethosu_patterns.QnnDepthwiseConv2DParams(post.op.body) + params.ifm.tensor = post.args[0] + channels_map = { + "NHWC": 3, + } + if str(params.ofm.layout) not in channels_map.keys(): + raise UnsupportedLayout(str(params.ofm.layout)) + kernel_shape_map = { + "HWOI": params.weights.shape[0:2], + } + if str(params.weights.layout) not in kernel_shape_map.keys(): + raise UnsupportedLayout(str(params.weights.layout)) + + weights_values = params.weights.values + weights_values_ohwi = np.moveaxis(weights_values, [0, 1, 2, 3], [1, 2, 0, 3]) + + activation = "NONE" + # Activations requiring LUT is currently not supported, so setting it to an empty list + lut = relay.const([], "int8") + clip_min = 0 + clip_max = 0 + if params.activation: + activation = ethosu_patterns.QnnDepthwiseConv2DParams.activation_map[ + params.activation.op.name + ] + if activation == "CLIP": + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + scale_bias = vela_api.pack_biases( + biases=params.biases.tensor.data.asnumpy(), + ifm_scale=params.ifm.q_params.scale_f32, + ifm_dtype=np.dtype(params.ifm.dtype), + weight_scales=params.weights.q_params.scale_f32, + ofm_scale=params.ofm.q_params.scale_f32, + is_activation_tanh_or_sigmoid=activation in ["TANH", "SIGMOID"], + ) + + ethosu_depthwise_conv2d = ethosu_ops.ethosu_depthwise_conv2d( + post.args[0], # IFM + relay.const(weights_values_ohwi, params.weights.values.dtype), + relay.const(scale_bias, "uint8"), + lut, + float(params.ifm.q_params.scale_f32), + int(params.ifm.q_params.zero_point), + int(params.weights.q_params.zero_point), + float(params.ofm.q_params.scale_f32), + int(params.ofm.q_params.zero_point), + kernel_shape_map[str(params.weights.layout)], + params.ofm.shape[channels_map[str(params.ofm.layout)]], + strides=params.strides, + padding=params.padding, + dilation=params.dilation, + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + upscale="NONE", + ifm_layout=str(params.ifm.layout), + ofm_layout=str(params.ofm.layout), + ) + return ethosu_depthwise_conv2d + + +@ir.transform.module_pass(opt_level=1) +class LegalizeEthosUDepthwiseConv2D: + """This is the pass that wraps the EthosUDepthwiseConv2DRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(EthosuDepthwiseConv2DRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + @ir.transform.module_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation @@ -220,6 +313,7 @@ def transform_module( ) -> tvm.ir.IRModule: mod = LegalizeSplit()(mod) mod = LegalizeEthosUConv2D()(mod) + mod = LegalizeEthosUDepthwiseConv2D()(mod) return mod def __call__(self, *args, **kwargs): diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py index 0406298f23f4..1063db6a04c5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py @@ -17,3 +17,4 @@ "Relay operators for the Arm(R) Ethos(TM)-U NPU" from .convolution import ethosu_conv2d +from .depthwise import ethosu_depthwise_conv2d diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py new file mode 100644 index 000000000000..abcddf90b97c --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py @@ -0,0 +1,205 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Relay operator for depthwise convolution""" +from typing import Tuple + +import tvm +from tvm.relay.op import _make +from tvm.topi.generic import schedule_injective +from tvm.relay.op.op import OpStrategy +from tvm.relay.op import strategy as _strategy + +from ..te import depthwise_conv2d_compute + + +def _extract_ethosu_depthwise_conv2d_params(attrs, args): + """Get the parameters necessary to construct a ethosu_depthwise_conv2d compute TE + from a ethosu_depthwise_conv2d Relay call.""" + ifm = args[0] + weight = args[1] + scale_bias = args[2] + lut = args[3] + ifm_scale = attrs.ifm_scale + ifm_zero_point = attrs.ifm_zero_point + weight_zero_point = attrs.weight_zero_point + ofm_scale = attrs.ofm_scale + ofm_zero_point = attrs.ofm_zero_point + strides = attrs.strides + padding = attrs.padding + dilation = attrs.dilation + activation = attrs.activation + clip_min = attrs.clip_min + clip_max = attrs.clip_max + upscale = attrs.upscale + ifm_layout = attrs.ifm_layout + ofm_layout = attrs.ofm_layout + + return ( + ifm, + weight, + scale_bias, + lut, + ifm_scale, + ifm_zero_point, + weight_zero_point, + ofm_scale, + ofm_zero_point, + strides, + padding, + dilation, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) + + +@tvm.ir.register_op_attr("contrib.ethosu.depthwise_conv2d", "FTVMCompute") +def create_ethosu_depthwise_conv2d_compute(attrs, args, out_type): + """Create an ethosu_depthwise_conv2d compute op.""" + params = _extract_ethosu_depthwise_conv2d_params(attrs, args) + op = depthwise_conv2d_compute(*params) + return [op] + + +@tvm.ir.register_op_attr("contrib.ethosu.depthwise_conv2d", "FTVMStrategy") +def depthwise_conv2d_strategy_ethosu(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + create_ethosu_depthwise_conv2d_compute, + _strategy.wrap_topi_schedule(schedule_injective), + name="ethosu_depthwise_conv2d", + ) + return strategy + + +def ethosu_depthwise_conv2d( + ifm: tvm.relay.Expr, + weight: tvm.relay.Expr, + scale_bias: tvm.relay.Expr, + lut: tvm.relay.Expr, + ifm_scale: float, + ifm_zero_point: int, + weight_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + kernel_shape: Tuple[int, int], + ofm_channels: int, + strides: Tuple[int, int] = (1, 1), + padding: Tuple[int, int, int, int] = (0, 0, 0, 0), + dilation: Tuple[int, int] = (1, 1), + activation: str = "NONE", + clip_min: int = 0, + clip_max: int = 0, + upscale: str = "NONE", + ifm_layout: str = "NHWC", + ofm_layout: str = "NHWC", +) -> tvm.relay.Call: + """This is a quantized 2D depthwise convolution operation as supported + by the NPU. It accepts either NHWC or NHCWB16 format + for the input data and OHWI format for the kernel weights. + + Reference: https://developer.arm.com/documentation/102420/0200/ + + Note that the per-channel weight scale and bias tensor must be + packed together into a combined tensor of uint80s. This is represented + in TVM by a (channels, 10) tensor of type uint8. For more detail, + refer to the Technical Reference Manual linked above. + + Parameters + ---------- + ifm : tvm.relay.Expr + The Input Feature Map tensor (IFM). + weight : tvm.relay.Expr + The weight tensor. + scale_bias : tvm.relay.Expr + The packed per-channel weight scale and bias tensor. + lut : tvm.relay.Expr + The look-up table values to use if activation = "LUT" + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + weight_zero_point : int + The quantization zero point for the weight tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + kernel_shape : tuple of int + The 2 dimensional kernel shape as (kernel_height, kernel_width). + ofm_channels : int + The number of OFM channels. + strides : tuple of int, optional + The 2 dimensional strides as (stride_height, stride_width). + padding : tuple of int, optional + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + dilation : tuple of int, optional + The 2 dimensional dilation as (dilation_height, dilation_width). + activation : str, optional + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform + the activation function. + clip_min : int, optional + The minimum clipping value if activation = "CLIP" + clip_max : int, optional, + The maximum clipping value if activation = "CLIP" + upscale : str, optional + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + out : tvm.relay.Call + A call to the ethosu_depthwise_conv2d op. + + """ + return _make.ethosu_depthwise_conv2d( + ifm, + weight, + scale_bias, + lut, + ifm_scale, + ifm_zero_point, + weight_zero_point, + ofm_scale, + ofm_zero_point, + kernel_shape, + ofm_channels, + strides, + padding, + dilation, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py index 7ca5de3c160c..5dcdd4dcf602 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py @@ -17,3 +17,4 @@ """Tensor Expressions for the NPU""" from .convolution import * +from .depthwise import * diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py new file mode 100644 index 000000000000..35ae7f9a700a --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Tensor Expressions for depthwise convolutions""" +from typing import Tuple, Union, List + +from tvm import te +from .dma import dma_ofm_compute, dma_ifm_compute + + +def depthwise_conv2d_compute( + ifm: te.Tensor, + weight: te.Tensor, + scale_bias: te.Tensor, + lut: te.Tensor, + ifm_scale: float, + ifm_zero_point: int, + weight_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + strides: Tuple[int, int], + padding: Tuple[int, int, int, int], + dilation: Union[Tuple[int, int], List[int]], + activation: str, + clip_min: int, + clip_max: int, + upscale: str, + ifm_layout: str, + ofm_layout: str, +) -> te.Tensor: + """A compute operator representing the capabilities of 2D convolution for the NPU. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map tensor (IFM). + weight : te.Tensor + The weight tensor. + scale_bias : te.Tensor + The packed per-channel weight scale and bias tensor. + lut : te.Tensor + The look-up table values to use if activation = "LUT". + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + weight_zero_point : int + The quantization zero point for the weight tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + strides : tuple + The 2 dimensional strides as (stride_height, stride_width). + padding : tuple + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + dilation : Union[int, tuple, list] + The 2 dimensional dilation as (dilation_height, dilation_width). + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int + The minimum clipping value if activation = "CLIP". + clip_max : int + The maximum clipping value if activation = "CLIP". + upscale : str + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + te.Tensor + The OFM tensor. + + """ + assert ifm.shape[0] == 1, f"Only batch size 1 is supported" + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + + stride_h, stride_w = strides + dilation_h, dilation_w = dilation + channels, kernel_h, kernel_w, _ = weight.shape + + # Compute operation for the IFM DMA pipeline + dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, channels, padding) + + # 2D Depthwise Convolution compute operation + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + ofm_height = (dmaed_ifm.shape[1] - dilated_kernel_h) // stride_h + 1 + ofm_width = (dmaed_ifm.shape[2] - dilated_kernel_w) // stride_w + 1 + rh = te.reduce_axis((0, kernel_h), name="ry") + rw = te.reduce_axis((0, kernel_w), name="rx") + + depthwise_conv2d_attrs = { + "op": "ethosu_depthwise_conv2d", + "weight_zero_point": weight_zero_point, + "activation": activation, + "upscale": upscale, + "clip_min": clip_min, + "clip_max": clip_max, + "stride_h": stride_h, + "stride_w": stride_w, + "dilation_h": dilation_h, + "dilation_w": dilation_w, + } + + depthwise = te.compute( + (1, ofm_height, ofm_width, channels), + lambda nn, hh, ww, cc: te.sum( + dmaed_ifm( + nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, cc + ).astype(ifm.dtype) + * weight[cc, rh, rw, 0].astype(ifm.dtype) + # This is a trick to load 10 elements of the scale_bias at once, not accurate maths + + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype), + axis=[rh, rw], + ), + name="ethosu_depthwise_conv2d", + attrs=depthwise_conv2d_attrs, + ) + + # Compute operation for the OFM DMA pipeline + return dma_ofm_compute(depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index c59a386fefbb..c792ade06643 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler""" +"""The integration of the Arm(R) Ethos(TM)-U NPU TIR compiler.""" import tvm from tvm import relay from tvm.relay.expr_functor import ExprMutator -from tvm.driver.build_module import get_binds +from tvm.driver.build_module import schedule_to_module from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants from .scheduler import schedule @@ -29,7 +29,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target. The resulting TIR module will contain a single function - that comprises of a sequence of tir.extern_calls to NPU + that consists of a sequence of tir.extern_calls to NPU operations. Parameters @@ -64,22 +64,17 @@ def lower_ethosu(sch, args, const_dict, name="main"): "no_unroll_loop_with_extent_one": True, }, "tir.UnrollLoop": {"auto_max_depth": -1}, + "tir.noalias": True, + "tir.debug_keep_trivial_loop": True, } # Merge two configs curr_cfg = {**curr_cfg, **tir_compiler_cfg} sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True) - compact = tvm.te.schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, None) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) with tvm.transform.PassContext(config=curr_cfg): + mod = schedule_to_module(sch, args, name) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.UnrollLoop()(mod) @@ -96,20 +91,20 @@ def lower_ethosu(sch, args, const_dict, name="main"): def lower_to_te(prim_func): - """Lower a Relay primitive function to a Tensor Expression graph. + """Lower a Relay primitive function to a Tensor Expression in an unscheduled CachedFunc. Parameters ---------- prim_func : tvm.relay.Function - The Relay function to lowerethosu_runtime([]). + The Relay function to lower. Returns ------- - out : TEGraph - The lowered Tensor Expression graph. + out : CachedFunc + The lowered Tensor Expression as part of a CachedFunc. """ - f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE") + f = tvm._ffi.get_global_func("relay.backend.LowerToTE") return f(prim_func) @@ -193,7 +188,7 @@ def lower_to_tir(func, cascader=None): func, consts = extract_constants(func) mod = tvm.IRModule.from_expr(func) func = relay.transform.InferType()(mod)["main"] - te_graph = lower_to_te(func) - s = schedule(te_graph, consts, cascader) - mod, consts = lower_ethosu(s, te_graph, consts) + cached_func = lower_to_te(func) + s = schedule(cached_func, consts, cascader) + mod, consts = lower_ethosu(s, cached_func, consts) return mod, consts diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 33fbdcd2b24f..fd7fa293ccfb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the convolution operators in TIR.""" +"""Extract parameters from the convolution operators in TIR.""" import tvm from ..vela_api import SCALE_BIAS_LENGTH from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py new file mode 100644 index 000000000000..27111a970b27 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the depthwise convolution operators in TIR.""" +from typing import Dict, Tuple +import tvm +from ..vela_api import SCALE_BIAS_LENGTH +from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores +from .dma import get_ifm_params, get_ofm_params +from .spec import ( + SerialKernel, + SerialAddressRange, + SerialActivation, + Serial2DDepthwise, +) + + +def get_depthwise_conv2d_params( + stmt: tvm.tir.AttrStmt, + producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], +) -> Tuple[Serial2DDepthwise, tvm.tir.Var, tvm.tir.Var]: + """Get the parameters necessary to construct a call_extern for a depthwise_conv2d. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a depthwise loop nest. + producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + Serial2DDepthwise + The parameters needed to construct a 2D depthwise. + output_pointer : tvm.tir.Var + The output pointer of the convolution operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the convolution output pointer. + + """ + attrs, body = get_op_attrs(stmt) + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + rh = inner + rw = rh.body + # loads = [output, input, weights, scale_bias, scale_bias] + loads = get_loads(rw.body) + # stores = [output] + stores = get_stores(rw.body) + input_pointer = loads[1].buffer_var + output_pointer = stores[0].buffer_var + # Get feature map info + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get kernel info + serial_kernel = SerialKernel( + width=int(rw.extent), + height=int(rh.extent), + stride_w=int(attrs["stride_w"]), + stride_h=int(attrs["stride_h"]), + dilation_w=int(attrs["dilation_w"]), + dilation_h=int(attrs["dilation_h"]), + ) + # Get scale_bias info + scale_bias_load = loads[3] + scale_bias_base = get_base_address(scale_bias_load.index) + serial_scale_bias = SerialAddressRange( + address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + length=SCALE_BIAS_LENGTH * serial_ofm[3], + ) + # Get weight info + weight_load = loads[2] + weight_base = get_base_address(weight_load.index) + serial_weight = SerialAddressRange( + address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1], + ) + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + + return ( + Serial2DDepthwise( + ifm=serial_ifm, + ofm=serial_ofm, + kernel=serial_kernel, + weight=serial_weight, + weight_zero_point=attrs["weight_zero_point"], + scale_bias=serial_scale_bias, + padding=serial_padding, + activation=serial_activation, + upscale="NONE", + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index ecd402d63309..a116e51c5b7c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the DMA operators in TIR.""" +"""Extract parameters from the DMA operators in TIR.""" import tvm from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs from .spec import SerialFeatureMap, SerialPadding diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 1af44962c141..761c8aad7bb1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler""" +"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler.""" import numpy as np # type: ignore import tvm from tvm.relay.backend.contrib.ethosu import vela_api from .convolution import get_conv2d_params +from .depthwise import get_depthwise_conv2d_params from .transform import get_copy_params from .utils import get_weights_pointer, get_scale_bias_pointer @@ -52,6 +53,7 @@ def ReplaceOperators(): op_map = { "ethosu_conv2d": get_conv2d_params, "ethosu_copy": get_copy_params, + "ethosu_depthwise_conv2d": get_depthwise_conv2d_params, } pointer_to_producer = {} pointer_to_consumer = {} @@ -299,7 +301,7 @@ def EncodeConstants(const_dict): pointer_to_buffer = {} rewrite_buffer = {} rewrite_pointer = {} - accel_type = vela_api.get_target_accel_type() # type: ignore + accel_config = vela_api.get_accelerator_config() def _align_scale_bias(tir_extern_call, bias): """Align the scale_bias to 16 bytes.""" @@ -314,7 +316,7 @@ def _align_scale_bias(tir_extern_call, bias): def _encode_weights(tir_extern_call, weights): """Encode the weights for a TIR extern call.""" - value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_type) + value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config) value = np.frombuffer(value_bytes, dtype="uint8") return value diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 5d9027bf2078..7f892d0c602a 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -15,17 +15,17 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Different schedulers for Arm(R) Ethos(TM)-U NPU""" +"""Scheduling for Arm(R) Ethos(TM)-U NPU.""" import tvm -def schedule(te_graph, const_dict, cascader=None): - """Schedule a TE graph for NPU compilation. +def schedule(cached_func, const_dict, cascader=None): + """Schedule a CachedFunc for NPU compilation. Parameters ---------- - te_graph - The TE graph to schedule. + cached_func : CachedFunc + The CachedFunc to schedule. const_dict : dict of int to numpy.ndarray The constant dictionary. cascader : callable, optional @@ -38,10 +38,10 @@ def schedule(te_graph, const_dict, cascader=None): The completed schedule for the graph. """ - s = tvm.te.create_schedule([t.op for t in te_graph.outputs]) + s = tvm.te.create_schedule([t.op for t in cached_func.outputs]) if cascader: - cascader(te_graph, const_dict, s) - inline_no_ops(te_graph, s) + cascader(cached_func, const_dict, s) + inline_no_ops(cached_func, s) schedule_pragmas(s) schedule_cache_reads(s) return s @@ -96,7 +96,7 @@ def total_cascader(stripe_size): """ - def _cascader(te_graph, const_dict, sch): + def _cascader(cached_func, const_dict, sch): scheduled = set() def _visit(tensor, stage, ax): @@ -106,8 +106,8 @@ def _visit(tensor, stage, ax): for input_tensor in tensor.op.input_tensors: _visit(input_tensor, stage, ax) - assert len(te_graph.outputs) == 1 - out = te_graph.outputs[0] + assert len(cached_func.outputs) == 1 + out = cached_func.outputs[0] oi, _ = tile_nd(sch, out, stripe_size) for ax in oi: sch[out].unroll(ax) @@ -126,14 +126,14 @@ def copy_constants(): The planning function. """ - def _planner(te_graph, const_dict, sch): + def _planner(cached_func, const_dict, sch): planned = set() # type: ignore def _visit(tensor, reader): if tensor is not planned: planned.add(tensor) if isinstance(tensor.op, tvm.te.PlaceholderOp): - index = list(te_graph.inputs).index(tensor) + index = list(cached_func.inputs).index(tensor) if index in const_dict: sch.cache_read(tensor, "global", [reader]) @@ -141,7 +141,7 @@ def _visit(tensor, reader): for input_tensor in tensor.op.input_tensors: _visit(input_tensor, tensor) - for output_tensor in te_graph.outputs: + for output_tensor in cached_func.outputs: _visit(output_tensor, None) return _planner @@ -216,7 +216,7 @@ def _detect_cache_read(stage): stage.pragma(fax, "op", "ethosu_copy") -def inline_no_ops(te_graph, sch): +def inline_no_ops(cached_func, sch): """Inline 'no-ops' - operations that in principle do nothing. Modifies the schedule in-place. For now we inline reshape and @@ -224,8 +224,8 @@ def inline_no_ops(te_graph, sch): Parameters ---------- - te_graph - The TE graph. + cached_func : CachedFunc + The cached func. sch : tvm.te.Schedule The schedule. @@ -241,7 +241,7 @@ def _visit(tensor): for input_tensor in tensor.op.input_tensors: _visit(input_tensor) - for out in te_graph.outputs: + for out in cached_func.outputs: _visit(out) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index 3ecbcd5f3cdc..ff019c7783db 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -203,7 +203,7 @@ def __init__( class Serial2DDepthwise(SerializableFormat): """Specialization class to retrieve arguments of - a ethosu.depthwise2d tir extern call on a predefined ordering""" + a ethosu.depthwise_conv2d TIR extern call on a predefined ordering""" def __init__( self, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index 0403ce2c7e8f..f50975c83838 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the transform operators in TIR.""" +"""Extract parameters from the transform operators in TIR.""" import tvm from .spec import SerialCopy from .utils import get_base_address, get_op_attrs diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index 7d6fd3bf82d8..de1c0ab19f6e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name -"""Helper utility functions used by the TIR compiler""" +"""Helper utility functions used by the NPU TIR compiler""" import tvm from tvm import arith @@ -23,7 +23,8 @@ # TODO(@mbaret): Formalise this with a specification def get_weights_pointer(tir_extern_call): """Get the weights pointer from a NPU extern call if it exists""" - if tir_extern_call.args[0] == "ethosu_conv2d": + supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] + if tir_extern_call.args[0] in supported_ops: return tir_extern_call.args[41].buffer_var return None @@ -31,7 +32,8 @@ def get_weights_pointer(tir_extern_call): # TODO(@mbaret): Formalise this with a specification def get_scale_bias_pointer(tir_extern_call): """Get the scale_bias pointer from a NPU extern call if it exists""" - if tir_extern_call.args[0] == "ethosu_conv2d": + supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] + if tir_extern_call.args[0] in supported_ops: return tir_extern_call.args[44].buffer_var return None diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 4b28dc5b191e..bcae01a10214 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -18,7 +18,7 @@ the Relay to TIR compilation process, to Vela API calls to generate command stream. """ -from typing import NamedTuple +from typing import Dict, NamedTuple, Tuple, Union from enum import auto from enum import Enum import numpy as np # type: ignore @@ -32,7 +32,7 @@ class BufferType(Enum): - """The buffer types the codegen supports""" + """The type of information that a buffer contains.""" constant = auto() input_or_output = auto() @@ -50,7 +50,7 @@ class BufferType(Enum): class BufferInfo(NamedTuple): - """A data structure to hold metadata of the buffer""" + """A data structure to hold metadata of the buffer.""" # If the buffer holds constants, the values will contain that otherwise None values: np.ndarray @@ -90,9 +90,9 @@ def translate(tir_module, params): for extern_call in extern_calls: _npu_ops.append(translate_ethosu_tir_extern_call(extern_call)) _npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops) - target_accel_type = vela_api.get_target_accel_type() - cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_type) - payload = vapi.npu_create_driver_payload(cmds, target_accel_type) + target_accel_config = vela_api.get_accelerator_config() + cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config) + payload = vapi.npu_create_driver_payload(cmds, target_accel_config) hex_value = "" if constant_tensor is None else constant_tensor.tobytes().hex() return payload.hex(), hex_value, scratch_size @@ -125,9 +125,10 @@ def populate_extern_calls(stmt): return extern_calls -def extract_buffer_info(mod, param_dict): - """ - This function is to read the tvm.IRModule that +def extract_buffer_info( + mod: tvm.IRModule, param_dict: Dict[int, np.ndarray] +) -> Dict[str, BufferInfo]: + """This function is to read the tvm.IRModule that contains Relay to TIR compiled IRModule. Thereafter, this will extract the buffer information as the shape and constant data (if any). @@ -136,12 +137,14 @@ def extract_buffer_info(mod, param_dict): ---------- mod : tvm.IRModule The NPU TIR IRModule. - param_dict : dict + param_dict : Dict[int, np.ndarray] A dictionary containing param idx --> const numpy.NDArray + Returns ------- - dict - a dictionary of buffer names --> BufferInfo + dict : Dict[str, BufferInfo] + A dictionary of buffer names --> BufferInfo + """ buffer_info = dict() # There should only be a single function @@ -299,6 +302,7 @@ def translate_ethosu_tir_extern_call(tir_extern_call): supported_extern_calls = { "ethosu_conv2d": translate_ethosu_conv2d, "ethosu_copy": translate_ethosu_copy, + "ethosu_depthwise_conv2d": translate_ethosu_depthwise_conv2d, } ext_call_type = tir_extern_call.args[0].value assert ext_call_type in supported_extern_calls.keys(), f"{ext_call_type} is not yet supported" @@ -327,14 +331,15 @@ def translate_ethosu_copy(tir_extern_call): return _create_npu_dma_op(serial_object) -def _convert_clip_bounds(npu_op): - """ - This function will convert the min and max value +def _convert_clip_bounds(npu_op: vapi.NpuBlockOperation): + """This function will convert the min and max value of clip activations to non quantized floats as expected by the API. + Parameters ---------- - npu_op : ethosu.vela.api.NpuBlockOperation + npu_op : vapi.NpuBlockOperation + """ clip_min_quant = npu_op.activation.min clip_max_quant = npu_op.activation.max @@ -348,13 +353,14 @@ def _convert_clip_bounds(npu_op): npu_op.activation.max = clip_max_actual -def translate_ethosu_conv2d(tir_extern_call): - """This function will translate a tir extern_call - as produced by Relay to TIR compilation. +def translate_ethosu_conv2d(tir_call_extern: tvm.tir.Call) -> Tuple[vapi.NpuConv2DOperation, int]: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. + Parameters ---------- - tir_extern_call : tvm.tir.Call - This should be an tir external call that has a agreed upon ordering + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has a agreed upon ordering for TIR Compiler. See Serial2DConvolution in tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. @@ -364,15 +370,18 @@ def translate_ethosu_conv2d(tir_extern_call): The vela object containing the params of ethosu_conv2d weights_zero_point : int The zero point of the weights + """ - # We skip the first element as it is the extern_call function name - serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_extern_call.args[1:]) + # We skip the first element as it is the call_extern function name + serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_call_extern.args[1:]) return _create_npu_op_conv2d(serial_object) -def _create_npu_op_conv2d(serial_2d_convolution): +def _create_npu_op_conv2d( + serial_2d_convolution: spec.Serial2DConvolution, +) -> Tuple[vapi.NpuConv2DOperation, int]: """This is a helper function to capture a list - of arguments to create Vela NpuConv2DOperation object + of arguments to create Vela NpuConv2DOperation object. """ npu_conv2d_op = vapi.NpuConv2DOperation() npu_conv2d_op.ifm = _create_npu_feature_map(serial_2d_convolution.ifm) @@ -391,8 +400,8 @@ def _create_npu_op_conv2d(serial_2d_convolution): _convert_clip_bounds(npu_conv2d_op) npu_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale) - target_accel_type = vela_api.get_target_accel_type() # type: ignore - block_config = vela_api.get_optimal_block_config(npu_conv2d_op, target_accel_type) + accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_conv2d_op, accel_config) npu_conv2d_op.block_config = block_config weights_shape_ohwi = [ npu_conv2d_op.ofm.shape.depth, @@ -408,9 +417,57 @@ def _create_npu_op_conv2d(serial_2d_convolution): return npu_conv2d_op, weights_zero_point -def _create_npu_feature_map(serial_feature_map): +def translate_ethosu_depthwise_conv2d(tir_extern_call): + """This function will translate a tir extern_call + as produced by Relay to TIR compilation. + + Parameters + ---------- + tir_extern_call : tvm.tir.Call + This should be a tir external call that has an agreed upon ordering + for NPU TIR Compiler. See Serial2DDepthwise in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. + + Returns + ------- + ethosu.vela.api.NpuDepthWiseOperation + The vela object containing the params of ethosu_depthwise_conv2d + weights_zero_point : int + The zero point of the weights + """ + serial_object = spec.create_serial_object(spec.Serial2DDepthwise, tir_extern_call.args[1:]) + return _create_npu_op_depthwise_conv2d(serial_object) + + +def _create_npu_op_depthwise_conv2d(serial_2d_depthwise): + npu_depthwise_conv2d_op = vapi.NpuConvDepthWiseOperation() + + npu_depthwise_conv2d_op.ifm = _create_npu_feature_map(serial_2d_depthwise.ifm) + npu_depthwise_conv2d_op.ofm = _create_npu_feature_map(serial_2d_depthwise.ofm) + npu_depthwise_conv2d_op.kernel = _create_npu_kernel(serial_2d_depthwise.kernel) + npu_depthwise_conv2d_op.weights = [_create_npu_address_range(serial_2d_depthwise.weight)] + weights_zero_point = np.int64(serial_2d_depthwise.weight_zero_point.value) + npu_depthwise_conv2d_op.biases = [_create_npu_address_range(serial_2d_depthwise.scale_bias)] + npu_depthwise_conv2d_op.padding = _create_npu_padding(serial_2d_depthwise.padding) + + npu_depthwise_conv2d_op.activation = _create_npu_activation(serial_2d_depthwise.activation) + if ( + npu_depthwise_conv2d_op.activation + and npu_depthwise_conv2d_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_depthwise_conv2d_op) + + npu_depthwise_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale) + target_accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_config) + npu_depthwise_conv2d_op.block_config = block_config + + return npu_depthwise_conv2d_op, weights_zero_point + + +def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.NpuFeatureMap: """This is a helper function to capture a list - of arguments to create Vela NpuFeatureMap object + of arguments to create Vela NpuFeatureMap object. """ layout_map = {"NHWC": vapi.NpuLayout.NHWC, "NHCWB16": vapi.NpuLayout.NHCWB16} datatype_map = { @@ -427,14 +484,14 @@ def _create_npu_feature_map(serial_feature_map): nfm = vapi.NpuFeatureMap() nfm.data_type = datatype_map[data_type] nfm.shape = vapi.NpuShape3D( - int(serial_feature_map.height.value), - int(serial_feature_map.width.value), - int(serial_feature_map.channels.value), + int(serial_feature_map.height), + int(serial_feature_map.width), + int(serial_feature_map.channels), ) nfm.tiles = vapi.NpuTileBox( - int(serial_feature_map.tile_height_0.value), - int(serial_feature_map.tile_height_1.value), - int(serial_feature_map.tile_width_0.value), + int(serial_feature_map.tile_height_0), + int(serial_feature_map.tile_height_1), + int(serial_feature_map.tile_width_0), [ serial_feature_map.tile_address_0, serial_feature_map.tile_address_1, @@ -447,81 +504,75 @@ def _create_npu_feature_map(serial_feature_map): ) nfm.layout = layout_map[layout] nfm.strides = vapi.NpuShape3D( - int(serial_feature_map.stride_h.value), - int(serial_feature_map.stride_w.value), - int(serial_feature_map.stride_c.value), + int(serial_feature_map.stride_h), + int(serial_feature_map.stride_w), + int(serial_feature_map.stride_c), ) return nfm -def _create_npu_kernel(serial_kernel): +def _create_npu_kernel(serial_kernel: spec.SerialKernel) -> vapi.NpuKernel: """This is a helper function to capture a list - of arguments to create Vela NpuKernel object + of arguments to create Vela NpuKernel object. """ nknl = vapi.NpuKernel( - w=int(serial_kernel.width.value), - h=int(serial_kernel.height.value), - stride_x=int(serial_kernel.stride_w.value), - stride_y=int(serial_kernel.stride_h.value), - dilation_x=int(serial_kernel.dilation_w.value), - dilation_y=int(serial_kernel.dilation_h.value), + w=int(serial_kernel.width), + h=int(serial_kernel.height), + stride_x=int(serial_kernel.stride_w), + stride_y=int(serial_kernel.stride_h), + dilation_x=int(serial_kernel.dilation_w), + dilation_y=int(serial_kernel.dilation_h), ) return nknl -def _create_npu_address_range(serial_address_range): +def _create_npu_address_range( + serial_address_range: spec.SerialAddressRange, +) -> vapi.NpuAddressRange: """This is a helper function to capture a list - of arguments to create Vela NpuAddressRange object + of arguments to create Vela NpuAddressRange object. """ addr_range = vapi.NpuAddressRange( # region will be updated later region=0, address=serial_address_range.address, - length=int(serial_address_range.length.value), + length=int(serial_address_range.length), ) return addr_range def _create_npu_quantization( - scale, - zero_point, -): + scale: Union[tvm.tir.FloatImm, float], + zero_point: Union[tvm.tir.IntImm, int], +) -> vapi.NpuQuantization: """This is a helper function to capture a list - of arguments to create Vela NpuQuantization object + of arguments to create Vela NpuQuantization object. """ - # Scale could be an ndarray if per-channel quantization is available - if not isinstance(scale, tvm.tir.expr.Load): - if isinstance(scale.value, float): - scale = np.single(scale.value) - else: - assert isinstance(scale.value.value, float) - scale = np.single(scale.value.value) - q_params = vapi.NpuQuantization(scale_f32=scale, zero_point=zero_point.value) - return q_params + return vapi.NpuQuantization(scale_f32=float(scale), zero_point=int(zero_point)) def _create_npu_weights_zero_point( - zero_point, -): - """This is a helper function to capture the weights zero point""" - return zero_point.value + zero_point: Union[int, tvm.tir.IntImm], +) -> int: + """This is a helper function to capture the weights zero point.""" + return int(zero_point) -def _create_npu_padding(serial_padding): +def _create_npu_padding(serial_padding: spec.SerialPadding) -> vapi.NpuPadding: """This is a helper function to capture a list - of arguments to create Vela NpuPadding object""" + of arguments to create Vela NpuPadding object.""" padding = vapi.NpuPadding( - top=int(serial_padding.top.value), - left=int(serial_padding.left.value), - bottom=int(serial_padding.bottom.value), - right=int(serial_padding.right.value), + top=int(serial_padding.top), + left=int(serial_padding.left), + bottom=int(serial_padding.bottom), + right=int(serial_padding.right), ) return padding -def _create_npu_activation(serial_activation): +def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.NpuActivation: """This is a helper function to capture a list - of arguments to create Vela NpuActivation object""" + of arguments to create Vela NpuActivation object.""" if serial_activation.op == "NONE": return None if ( @@ -538,16 +589,16 @@ def _create_npu_activation(serial_activation): op = str(serial_activation.op.value) assert op in op_map.keys() act_op = vapi.NpuActivation(op_map[op]) - act_op.min = int(serial_activation.clip_min.value) - act_op.max = int(serial_activation.clip_max.value) + act_op.min = int(serial_activation.clip_min) + act_op.max = int(serial_activation.clip_max) return act_op def _create_npu_resampling_mode( - mode, -): + mode: str, +) -> vapi.NpuResamplingMode: """This is a helper function to capture a list - of arguments to create Vela NpuResamplingMode object""" + of arguments to create Vela NpuResamplingMode object.""" mode_map = { "NONE": vapi.NpuResamplingMode.NONE, "NEAREST": vapi.NpuResamplingMode.NEAREST, diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 5009c3157c77..69095e43416e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -27,6 +27,7 @@ import numpy as np # type: ignore from ethosu.vela import api as vapi # type: ignore +import tvm from tvm.relay.backend.contrib.ethosu import util # type: ignore from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs @@ -45,7 +46,7 @@ def get_optimal_block_config( - npu_op: vapi.NpuOperation, accel_type: vapi.NpuAccelerator + npu_op: vapi.NpuOperation, accel_config: vapi.NpuAccelerator ) -> vapi.NpuShape3D: """ "The NPU's unit of work is known as a block. It will fetch block(s) from Input @@ -58,15 +59,15 @@ def get_optimal_block_config( ---------- npu_op : ethosu.vela.api.NpuOperation The NPU operation and its params - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config Returns ------- ethosu.vela.api.NpuShape3D : The optimal block config for the operator """ - all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_type) + all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_config) return _get_optimal_block_config(all_valid_block_configs) @@ -112,7 +113,9 @@ def _get_optimal_block_config(all_valid_block_configs: List[vapi.NpuShape3D]) -> return max_area_depth_block_configs[0] -def encode_weights(tir_extern_call, values, accel_type): +def encode_weights( + tir_extern_call: tvm.tir.Call, values: np.ndarray, accel_config: vapi.NpuAccelerator +): """This is an API function to compress weights by passing a tir_extern_call to NPU Convolution operation and values. @@ -122,26 +125,30 @@ def encode_weights(tir_extern_call, values, accel_type): tir_extern_call to NPU Convolution operation values : numpy.ndarray The constant flattened weight data in OHWI layout - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config Returns ------- bytearray Compressed weights """ - supported_ops = ["ethosu_conv2d"] + supported_ops = { + "ethosu_conv2d": tirtocs.translate_ethosu_conv2d, + "ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d, + } op = str(tir_extern_call.args[0].value) - assert op in supported_ops - npu_op, weights_zero_point = tirtocs.translate_ethosu_conv2d(tir_extern_call) - block_config = get_optimal_block_config(npu_op, accel_type) + assert op in supported_ops.keys() + npu_op, weights_zero_point = supported_ops[op](tir_extern_call) + block_config = get_optimal_block_config(npu_op, accel_config) # The weight layout is assumed to be flat OHWI, always. assert len(values.shape) == 1 + is_depthwise = op == "ethosu_depthwise_conv2d" shape_ohwi = ( npu_op.ofm.shape.depth, npu_op.kernel.height, npu_op.kernel.width, - npu_op.ifm.shape.depth, + 1 if is_depthwise else npu_op.ifm.shape.depth, ) assert values.size == np.prod(shape_ohwi) values = np.reshape(values, shape_ohwi) @@ -153,9 +160,8 @@ def encode_weights(tir_extern_call, values, accel_type): ifm_bitdepth=npu_op.ifm.data_type.size_in_bits(), block_depth=block_config.depth, dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y), - accel_type=accel_type, - # TODO(@manupa-arm): change this when we support depthwise - is_depthwise=False, + accel_config=accel_config, + is_depthwise=is_depthwise, ) @@ -166,7 +172,7 @@ def compress_weights( ifm_bitdepth: int, block_depth: int, dilation: Tuple[int, int], - accel_type: vapi.NpuAccelerator, + accel_config: vapi.NpuAccelerator, is_depthwise: Optional[bool] = False, ) -> bytearray: """The NPU requires the weights to be compressed @@ -188,8 +194,8 @@ def compress_weights( The depth of the optimal block config for the operator dilation : tuple A tuple of 2 elements indicating dilation in h and w - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config is_depthwise : bool, Optional This indicates whether the weights are compressed for depthwise convolution @@ -212,7 +218,7 @@ def compress_weights( ] block_traversal = calculate_block_traversal_mode(is_depthwise, shape_ohwi, ifm_bitdepth) compressed_weights = vapi.npu_encode_weights( - accelerator=accel_type, + accelerator=accel_config, weights_volume=weights_ohwi, dilation_xy=dilation, ifm_bitdepth=ifm_bitdepth, @@ -358,15 +364,24 @@ def _calculate_hw_bias_scales( return hw_bias_scales -def get_target_accel_type(): - """This is a helper function to convert cli accelerator type str argument - to NpuAccelerator""" +def get_accelerator_config() -> vapi.NpuAccelerator: + """Get the configuration of the NPU accelerator. + + The configuration string provided as a compiler option is converted into + an NpuAccelerator object. Valid configuration strings: + - 'ethos-u55-256' + - 'ethos-u55-128' + - 'ethos-u55-64' + - 'ethos-u55-32' + + """ npu_accel_str_map = { "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, "ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128, "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, } - accel_type_str = util.get_accelerator_config() - assert accel_type_str in npu_accel_str_map.keys(), f"{accel_type_str} is not supported" - return npu_accel_str_map[accel_type_str] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str in npu_accel_str_map.keys(), f"{accel_config_str} is not supported" + return npu_accel_str_map[accel_config_str] diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/te_compiler.py similarity index 79% rename from python/tvm/relay/backend/compile_engine.py rename to python/tvm/relay/backend/te_compiler.py index e9129db7b200..db7504915887 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=len-as-condition,no-else-return,invalid-name -"""Backend code generation engine.""" +"""TE compiler engine (replacing legacy compile_engine).""" from __future__ import absolute_import import logging -import numpy as np import tvm from tvm import te, autotvm from tvm.ir.transform import PassContext @@ -31,7 +30,7 @@ from .. import ty as _ty from . import _backend -logger = logging.getLogger("compile_engine") +logger = logging.getLogger("te_compiler") autotvm_logger = logging.getLogger("autotvm") _first_warning = True @@ -47,7 +46,7 @@ def __init__(self, outputs, implement): @tvm._ffi.register_object("relay.CCacheKey") class CCacheKey(Object): - """Key in the CompileEngine. + """Key in the TE Compiler. Parameters ---------- @@ -64,7 +63,7 @@ def __init__(self, source_func, target): @tvm._ffi.register_object("relay.CCacheValue") class CCacheValue(Object): - """Value in the CompileEngine, including usage statistics.""" + """Value in the TE Compiler, including usage statistics.""" def _get_cache_key(source_func, target): @@ -79,24 +78,6 @@ def _get_cache_key(source_func, target): return source_func -def get_shape(shape): - """Convert the shape to correct dtype and vars.""" - ret = [] - for dim in shape: - if isinstance(dim, tvm.tir.IntImm): - if libinfo()["INDEX_DEFAULT_I64"] == "ON": - ret.append(dim) - else: - val = int(dim) - assert val <= np.iinfo(np.int32).max - ret.append(tvm.tir.IntImm("int32", val)) - elif isinstance(dim, tvm.tir.Any): - ret.append(te.var("any_dim", "int32")) - else: - ret.append(dim) - return ret - - def get_valid_implementations(op, attrs, inputs, out_type, target): """Get all valid implementations from the op strategy. @@ -275,6 +256,24 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) return best_plevel_impl, outputs[best_plevel_impl] +def get_shape(shape): + """Convert the shape to correct dtype and vars.""" + ret = [] + for dim in shape: + if isinstance(dim, tvm.tir.IntImm): + if libinfo()["INDEX_DEFAULT_I64"] == "ON": + ret.append(dim) + else: + val = int(dim) + assert val <= np.iinfo(np.int32).max + ret.append(tvm.tir.IntImm("int32", val)) + elif isinstance(dim, tvm.tir.Any): + ret.append(te.var("any_dim", "int32")) + else: + ret.append(dim) + return ret + + @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" @@ -322,12 +321,12 @@ def lower_call(call, inputs, target): return LoweredOutput(outputs, best_impl) -@tvm._ffi.register_object("relay.CompileEngine") -class CompileEngine(Object): - """CompileEngine to get lowered code.""" +@tvm._ffi.register_object("relay.TECompiler") +class TECompiler(Object): + """TECompiler to get lowered code.""" def __init__(self): - raise RuntimeError("Cannot construct a CompileEngine") + raise RuntimeError("Cannot construct a TECompiler") def lower(self, source_func, target=None, mod_name="default"): """Lower a source_func to a CachedFunc. @@ -349,7 +348,7 @@ def lower(self, source_func, target=None, mod_name="default"): try: mod_name = mangle_module_name(mod_name) key = _get_cache_key(source_func, target) - return _backend._CompileEngineLower(self, key, mod_name) + return _backend._TECompilerLower(self, key, mod_name) except Exception: import traceback @@ -360,10 +359,6 @@ def lower(self, source_func, target=None, mod_name="default"): msg += "--------------------------\n" raise RuntimeError(msg) - def lower_shape_func(self, source_func, target=None): - key = _get_cache_key(source_func, target) - return _backend._CompileEngineLowerShapeFunc(self, key) - def jit(self, source_func, target=None): """JIT a source_func to a tvm.runtime.PackedFunc. @@ -381,87 +376,30 @@ def jit(self, source_func, target=None): The result of jited function. """ key = _get_cache_key(source_func, target) - return _backend._CompileEngineJIT(self, key) + return _backend._TECompilerJIT(self, key) def clear(self): """clear the existing cached functions""" - _backend._CompileEngineClear(self) + _backend._TECompilerClear(self) def items(self): """List items in the cache. - Returns ------- item_list : List[Tuple[CCacheKey, CCacheValue]] The list of items. """ - res = _backend._CompileEngineListItems(self) - assert len(res) % 2 == 0 - return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] - - def shape_func_items(self): - """List items in the shape_func_cache. - - Returns - ------- - item_list : List[Tuple[CCacheKey, CCacheValue]] - The list of shape_func_items. - """ - res = _backend._CompileEngineListShapeFuncItems(self) + res = _backend._TECompilerListItems(self) assert len(res) % 2 == 0 return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] - def get_current_ccache_key(self): - return _backend._CompileEngineGetCurrentCCacheKey(self) - - def dump(self): - """Return a string representation of engine dump. - - Returns - ------- - dump : str - The dumped string representation - """ - items = self.items() - res = "====================================\n" - res += "CompilerEngine dump, %d items cached\n" % len(items) - for k, v in items: - res += "------------------------------------\n" - res += "target={}\n".format(k.target) - res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) - res += "----relay function----\n" - res += k.source_func.astext() + "\n" - res += "----tir function----- \n" - res += "inputs={}\n".format(v.cached_func.inputs) - res += "outputs={}\n".format(v.cached_func.outputs) - res += "function: \n" - res += v.cached_func.funcs.astext() + "\n" - res += "===================================\n" - shape_func_items = self.shape_func_items() - res += "%d shape_func_items cached\n" % len(shape_func_items) - for k, v in shape_func_items: - res += "------------------------------------\n" - res += "target={}\n".format(k.target) - res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) - res += "----relay function----\n" - res += k.source_func.astext() + "\n" - res += "----tir function----- \n" - res += "inputs={}\n".format(v.cached_func.inputs) - res += "outputs={}\n".format(v.cached_func.outputs) - res += "function: \n" - res += v.cached_func.funcs.astext() + "\n" - res += "===================================\n" - return res - def get(): - """Get the global compile engine. + """Get the global TE Compiler. Returns ------- - engine : tvm.relay.backend.CompileEngine - The compile engine. + engine : tvm.relay.backend.TECompiler + The TE Compiler. """ - return _backend._CompileEngineGlobal() + return _backend._TECompilerGlobal() diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index c67ac1dc423d..f1686d2a03bb 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -123,7 +123,7 @@ def build( to setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. params : dict of str to NDArray Input parameters to the graph that do not change @@ -303,7 +303,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. params : dict of str to NDArray Input parameters to the graph that do not change @@ -452,7 +452,7 @@ def bind_params_by_name(func, params): class GraphExecutor(_interpreter.Executor): """Wrapper around Executor interface. - This executor is used for debug and testing purpoes. + This executor is used for debug and testing purposes. Parameters ---------- diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index b8273b0324c0..be76feef7297 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -50,6 +50,7 @@ def __init__(self, init_layer_dict, predict_layer, exp_tab): "Deconvolution": self.convert_deconv, "Dropout": self.convert_dropout, "Eltwise": self.convert_eltwise, + "Embed": self.convert_embed, "Flatten": self.convert_flatten, "InnerProduct": self.convert_innerproduct, "Input": None, @@ -593,6 +594,46 @@ def convert_crop(self, op): out = _op.slice_like(in_expr_a_stride, in_expr_b, axes=to_crop_axis) return out + def convert_embed(self, op): + """Convert Embed layer""" + inputs = op.bottom + embed_param = op.embed_param + num_output = embed_param.num_output + input_dim = embed_param.input_dim + bias_term = embed_param.bias_term + weight_bias_blobs = self.init_layer_dict[op.name].blobs + weight, bias = None, None + if bias_term: + weight = weight_bias_blobs[0] + bias = weight_bias_blobs[1] + assert weight and bias + else: + weight = weight_bias_blobs[0] + assert weight + weight_value = np.asarray(weight.data, np.float32) + weight_value = np.reshape(weight_value, [input_dim, num_output]) + weight_expr = self.exp_tab.new_const(weight_value, dtype="float32") + in_expr = self.exp_tab.get_expr(inputs[0]) + input_shape = _infer_shape(in_expr) + input_count = 1 + for dim in input_shape: + input_count *= dim + + index = _op.cast(in_expr, "int32") + out = _op.take(weight_expr, index, axis=0) + + if bias_term: + bias_value = np.asarray(bias.data, np.float32) + bias_expr = self.exp_tab.new_const(bias_value, dtype="float32") + out = _op.reshape(out, [input_count, num_output]) + out = _op.add(out, bias_expr) + + out_shape = list(input_shape) + out_shape.append(num_output) + out = _op.reshape(out, out_shape) + + return out + def check_unsupported_ops(self): """Check unsupported Caffe ops in our converter.""" unsupported_ops_set = set() diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 3a4897ad3166..cf579923e301 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -28,6 +28,7 @@ from .. import function as _function from .. import transform as _transform from .. import op as _op +from .. import ty as _ty from .. import analysis # pylint: disable=invalid-name @@ -594,6 +595,16 @@ def try_infer_value(val, on_success=None, on_failure=None): return val, False +def shape_of(x, dtype="int64"): + """Get shape of a tensor.""" + + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + return _op.shape_of(x, dtype) + + def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"): return _expr.var(name_hint, type_annotation, shape, dtype) @@ -835,3 +846,97 @@ def lstm_cell( outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)] return outputs_list, hidden_state, cell_state + + +def autopad( + data, + strides, + kernel_shape, + dilations=(1, 1), + pad_type="constant", + deconv=False, + mode="SAME_UPPER", + pad_value=0.0, +): + """ + Perform autopadding with dynamic input shapes + """ + # get attributes as constants + strides = _op.const(np.array(strides), dtype="int64") + dilated_kernel_shape = _op.const( + np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ), + dtype="int64", + ) + # get input shape + ndim = len(infer_shape(data)) + shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) + + # set up integer constants + zero = _op.const(0, dtype="int64") + one = _op.const(1, dtype="int64") + two = _op.const(2, dtype="int64") + + # Calculate total padding + mod = _op.mod(shape, strides) + + left = _op.maximum(dilated_kernel_shape - strides, zero) + right = _op.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _op.where(_op.equal(mod, zero), left, right) + if deconv: + total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad + + # split total padding into before and after + pad_before = _op.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + # combine + if "LOWER" in mode: + pad = _op.concatenate( + [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 + ) + else: + pad = _op.concatenate( + [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) + + if isinstance(pad_value, (float, int)): + pad_value = _op.const(pad_value) + + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) + + +def ensure_scalar_shape(x): + """ + Assume that `x` is a tensor with one element (regardless of tensor rank). + Return a version of that tensor with rank 0. + """ + x_shape = infer_shape(x) + x_rank = len(x_shape) + + if x_rank == 0: + return x + + num_elem = np.prod(x_shape) + assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape) + + return _op.squeeze(x) + + +def try_resolve_var_to_const(x, graph_params): + """ + Try to resolve the value of tensor `x` to a specific value. + If successful, return a Const op with that value. + If unsuccessful, simply return `x`. + """ + if isinstance(x, _expr.Var) and x.name_hint in graph_params: + value = graph_params[x.name_hint].numpy() + dtype = infer_type(x).checked_type.dtype + return _op.const(value, dtype) + + return x diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index aa185923d02e..bf6293a2a90c 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -896,6 +896,7 @@ def _convert_lstm(inexpr, keras_layer, etab): in_data = _op.squeeze(in_data, axis=[0]) in_data = _op.split(in_data, indices_or_sections=time_steps, axis=0) # loop for the number of time_steps + out_list = [] # store h outputs in case return_sequences is True for data in in_data: ixh1 = _op.nn.dense(data, kernel_weight, units=units) ixh2 = _op.nn.bias_add(_op.nn.dense(next_h, recurrent_weight, units=units), bias=in_bias) @@ -906,8 +907,11 @@ def _convert_lstm(inexpr, keras_layer, etab): next_c = in_transform * next_c + in_gate * _convert_activation(gates[2], keras_layer, None) out_gate = _convert_recurrent_activation(gates[3], keras_layer) next_h = out_gate * _convert_activation(next_c, keras_layer, None) + if keras_layer.return_sequences: + out_list.append(_op.expand_dims(next_h, axis=1)) + out = _op.concatenate(out_list, axis=1) if keras_layer.return_sequences else next_h out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0]) - out = _op.reshape(next_h, newshape=out_shape) + out = _op.reshape(out, newshape=out_shape) return [out, next_h, next_c] diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 86cb178d0875..3c88f659f6f0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -38,8 +38,10 @@ from .. import ty as _ty from .. import vision as _vision from .common import ( + autopad, AttrCvt, Renamer, + ensure_scalar_shape, fold_constant, get_name, get_relay_op, @@ -50,6 +52,8 @@ infer_value, lstm_cell, new_var, + shape_of, + try_resolve_var_to_const, unbind, ) @@ -313,7 +317,6 @@ def _run_calculation(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], [1] * ndim, - ndim, pad_value=pad_val, mode=attr["auto_pad"], ) @@ -409,69 +412,6 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt(op_name="instance_norm")(inputs, attr, params) -def autopad( - data, - strides, - kernel_shape, - dilations, - ndim, - pad_type="constant", - deconv=False, - mode="SAME_UPPER", - pad_value=0.0, -): - """ - Perform autopadding with dynamic input shapes - """ - # get attributes as constants - strides = _op.const(np.array(strides), dtype="int64") - dilated_kernel_shape = _op.const( - np.array( - [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] - ), - dtype="int64", - ) - # get input shape - shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) - - # set up integer constants - zero = _op.const(0, dtype="int64") - one = _op.const(1, dtype="int64") - two = _op.const(2, dtype="int64") - - # Calculate total padding - mod = _op.mod(shape, strides) - - left = _op.maximum(dilated_kernel_shape - strides, zero) - right = _op.maximum(dilated_kernel_shape - mod, zero) - - total_pad = _op.where(_op.equal(mod, zero), left, right) - if deconv: - total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad - - # split total padding into before and after - pad_before = _op.floor_divide(total_pad, two) - pad_after = total_pad - pad_before - - # combine - if "LOWER" in mode: - pad = _op.concatenate( - [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 - ) - else: - pad = _op.concatenate( - [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 - ) - - # pad N and C with zeros - pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - - if isinstance(pad_value, (float, int)): - pad_value = _op.const(pad_value) - - return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) - - class Conv(OnnxOpConverter): """Operator converter for Conv.""" @@ -499,7 +439,6 @@ def _impl_v1(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": @@ -580,7 +519,6 @@ def _impl_v1(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, deconv=True, mode=attr["auto_pad"], ) @@ -972,7 +910,6 @@ def _impl_v1(cls, inputs, attr, params): attr["strides"], attr["kernel_shape"], [1] * ndim, - ndim, mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": @@ -1408,14 +1345,6 @@ def _impl_v9(cls, inputs, attr, params): return out -def shape_of(x, dtype="int64"): - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(shape, dtype) - return _op.shape_of(x, dtype) - - class Shape(OnnxOpConverter): """Operator converter for Shape.""" @@ -2695,6 +2624,40 @@ def _impl_v10(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): + scale = inputs[2] + scale_shape = infer_shape(scale) + if len(inputs) == 4: + assert ( + len(scale_shape) == 0 or scale_shape[0] == 0 + ), "One of scale or size should be passed, not both." + size = inputs[3] + else: + assert len(scale_shape) != 0, "One of scale or size should be passed." + size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale + return cls.v11_13_common(inputs, size, attr, params) + + @classmethod + def _impl_v13(cls, inputs, attr, params): + scale = inputs[2] + size = inputs[3] + if size is not None: + assert scale is None, "One of scale or size should be passed, not both." + else: + scale_type = infer_type(scale) + scale_shape = scale_type.checked_type.shape + scale_dtype = scale_type.checked_type.dtype + assert len(scale_shape) != 0, "One of scale or size should be passed." + size = _op.cast(shape_of(inputs[0]), scale_dtype) * scale + + return cls.v11_13_common(inputs, size, attr, params) + + @classmethod + def v11_13_common(cls, inputs, size, attr, params): + """ + Resize v11 and Resize v13 are identical except in how + they handle the passing of scale and size. This utility + provides the implementation for both + """ ndims = len(infer_shape(inputs[0])) mode = attr.get("mode").decode("ascii") if mode == "nearest": @@ -2713,16 +2676,6 @@ def _impl_v11(cls, inputs, attr, params): alpha = attr.get("cubic_coeff_a", -0.75) exclude = attr.get("exclude_outside", 0) - scale = inputs[2] - scale_shape = infer_shape(scale) - if len(inputs) == 4: - assert ( - len(scale_shape) == 0 or scale_shape[0] == 0 - ), "One of scale or size should be passed, not both." - size = inputs[3] - else: - assert len(scale_shape) != 0, "One of scale or size should be passed." - size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) out = None if ndims == 3: @@ -3414,7 +3367,6 @@ def _impl_v10(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, pad_value=x_zero_point.data, mode=attr["auto_pad"], ) @@ -3506,6 +3458,156 @@ def _impl_v10(cls, inputs, attr, params): return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype) +class QLinearMatMul(OnnxOpConverter): + """ + Operator converter for QLinearMatMul from Microsoft onnxruntime contrib opset. + + Limitations: + - Only supports 2D input tensors. + - Not guaranteed to meet the integer-overflow behavior stipulated in the + ONNX documentation for this operator. + """ + + @classmethod + def _impl_v10(cls, inputs, attr, params): + + # Some of the ops used below take scalar-like inputs, and may require either + # of the following: + # + # - the input is Const node (not merely an expression that *could* be reduced + # to a single Const at graph-compilation time) + # + # - the input has a specific dtype + # + # This function attempts to present 'x' in a form that meets both of those + # requirements. + def try_resolve_to_const_scalar(x, dtype_override=None): + x2 = try_resolve_var_to_const(x, params) + x3 = ensure_scalar_shape(x2) + + x_dtype = infer_type(x).checked_type.dtype + if (dtype_override is not None) and (dtype_override != x_dtype): + x4 = _op.cast(x3, dtype_override) + else: + x4 = x3 + + x5 = fold_constant(x4) + return x5 + + # Unpack the inputs and obtain some type info... + a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp = inputs + + a_type = infer_type(a).checked_type # 'T1' in ONNX doc for this op + a_scale_type = infer_type(a_scale).checked_type + a_zp_type = infer_type(a_zp).checked_type + + b_type = infer_type(b).checked_type # 'T2' in ONNX doc for this op + b_scale_type = infer_type(b_scale).checked_type + b_zp_type = infer_type(b_zp).checked_type + + y_scale_type = infer_type(y_scale).checked_type + y_zp_type = infer_type(y_zp).checked_type # 'T3' in ONNX doc for this op + + a_shape = infer_shape(a) + b_shape = infer_shape(b) + + # Verify type assumptions, based on the ONNX doc for this op... + assert a_type.dtype in ["int8", "uint8"] + assert a_scale_type.dtype == "float32" + assert a_zp_type.dtype == a_type.dtype + + assert b_type.dtype in ["int8", "uint8"] + assert b_scale_type.dtype == "float32" + assert b_zp_type.dtype == b_type.dtype + + assert y_scale_type.dtype == "float32" + assert y_zp_type.dtype in ["int8", "uint8"] + + # TODO: relax this limitation in a future version of this importer. + a_rank = len(a_shape) + b_rank = len(b_shape) + assert (a_rank == 2) and (b_rank == 2), ( + "QLinearMatMul importer currently requires both 'a' and 'b' tensors to be 2D, but" + " rank(a)={}, rank(b)={}".format(a_rank, b_rank) + ) + + # _qnn.op.dense requires the zero-point values to have dtype int32. + a_scale_scalar = try_resolve_to_const_scalar(a_scale) + a_zp_scalar = try_resolve_to_const_scalar(a_zp, "int32") + + b_scale_scalar = try_resolve_to_const_scalar(b_scale) + b_zp_scalar = try_resolve_to_const_scalar(b_zp, "int32") + + y_scale_scalar = try_resolve_to_const_scalar(y_scale) + y_zp_scalar = try_resolve_to_const_scalar(y_zp, "int32") + + # TODO: Confirm that we're using 'num_hidden_units' correctly / as intended with + # the '_qnn.op.dense' instance below. + num_hidden_units = infer_shape(b)[-1] + + # - Specify the matmul result dtype as int32, so that hopefully the matmul will use + # a 32-bit accumulator as seems to be required by the ONNX op's documentation. + # + # TL;DR: + # The ONNX documentation for this op is clear about acceptable overflow + # behavior during the matmul operation: + # - The scalar multiplication ops MAY NOT overflow. + # - The scalar addition ops, which sum the results of the scalar multiplication, + # MAY overflow, but if they do so, it must behave as one would expect during + # 32-bit integer-addition overflow. + # As of this writing, Relay's qnn.op.dense operator doesn't expose a way for us to + # express these constraints. + # + # TODO: Extend TVM / Relay / TIR / etc. to allow this kind of constraint to be + # expressed in a Relay graph. And then update this importer and various TVM + # backends accordingly. + matmul_result_dtype = "int32" + + matmul_result = _qnn.op.dense( + a, + _op.transpose(b), + a_zp_scalar, + b_zp_scalar, + a_scale_scalar, + b_scale_scalar, + num_hidden_units, + matmul_result_dtype, + ) + + # This information might only be found in the C++ code-comments for the + # dense.matmul op, but the quantized tensor returned by _qnn.op.dense + # has scale==(a_scale_scalar * b_scale_scalar), and zero_point==0. + # + # 'matmul_result_zp_scalar' has type 'int32' to satisfy input requirements + # of the [de/re]quantize ops below. + matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar)) + matmul_result_zp_scalar = _op.const(0, dtype="int32") + + # requantize requires y_scale to be constant, + # if y_scale is not constant, doing dequantize -> quantize + if isinstance(y_scale_scalar, _expr.Constant): + y = _qnn.op.requantize( + matmul_result, + matmul_result_scale_scalar, + matmul_result_zp_scalar, + y_scale_scalar, + y_zp_scalar, + axis=-1, + rounding="TONEAREST", + out_dtype=y_zp_type.dtype, + ) + else: + matmul_result_deq = _qnn.op.dequantize( + matmul_result, matmul_result_scale_scalar, matmul_result_zp_scalar, axis=0 + ) + + y = _qnn.op.quantize( + matmul_result_deq, y_scale_scalar, y_zp_scalar, axis=0, out_dtype=y_zp_type.dtype + ) + + return y + + class QLinearMul(OnnxOpConverter): """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset.""" @@ -3634,7 +3736,6 @@ def _impl_v10(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, pad_value=data_zp, mode=attr["auto_pad"], ) @@ -4234,6 +4335,7 @@ def _get_convert_map(opset): "QLinearConv": QLinearConv.get_converter(opset), "QLinearConcat": QLinearConcat.get_converter(opset), "QLinearAdd": QLinearAdd.get_converter(opset), + "QLinearMatMul": QLinearMatMul.get_converter(opset), "QLinearMul": QLinearMul.get_converter(opset), "QLinearSigmoid": QLinearSigmoid.get_converter(opset), "ConvInteger": ConvInteger.get_converter(opset), diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 378002a74416..ef361d6c55e8 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -18,11 +18,13 @@ # pylint: disable=import-outside-toplevel """Paddle: PArallel Distributed Deep LEarning.""" +import warnings import numpy as np import tvm from tvm.ir import IRModule +from ... import nd as _nd from .. import analysis from .. import ty as _ty from .. import expr as _expr @@ -30,11 +32,13 @@ from .. import ty as _ty from .. import op as _op from .common import ( + autopad, fold_constant, get_relay_op, infer_shape, infer_type, infer_value, + shape_of, try_infer_value, new_var, ) @@ -42,20 +46,6 @@ __all__ = ["from_paddle"] -def _get_pad_size(in_size, dilated_kernel_size, stride_size): - """Calculate the paddings size for Conv/Pool in SAME padding mode.""" - - if stride_size == 1 or in_size % stride_size == 0: - pad = max(dilated_kernel_size - stride_size, 0) - else: - pad = max(dilated_kernel_size - (in_size % stride_size), 0) - - pad_before = pad // 2 - pad_after = pad - pad_before - - return [pad_before, pad_after] - - def _dtype_shape_promotion(inputs): """Promote data type and shape for list of tensors.""" @@ -77,16 +67,6 @@ def _dtype_shape_promotion(inputs): return inputs -def shape_of(x, dtype="int32"): - """Get shape of a tensor.""" - - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(np.array(shape), dtype) - return _op.shape_of(x, dtype) - - def _convert_dtype_value(val): """Converts a Paddle type id to a string.""" @@ -247,24 +227,16 @@ def convert_conv2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - if strides[0] == 1 and strides[1] == 1: - pad_h = _get_pad_size(0, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(0, (k_w - 1) * dilations[1] + 1, strides[1]) - else: - input_shape = shape_of(input_x) - h_w = _op.strided_slice(input_shape, [2], [4]) - try: - in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() - except Exception as e: - msg = "Dynamic shape is not supported in SAME padding algorithm while stride!=1" - raise tvm.error.OpAttributeInvalid(msg) from e - pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + # Handle history issue of PaddlePaddle + # while padding_algorithm == "SAME" + # dilations will be set to [1, 1] + dilations = [1, 1] + input_x = autopad(input_x, strides, [k_h, k_w], dilations) + paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] - if len(paddings) == 4: + elif len(paddings) == 4: paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] else: msg = 'Value {} in attribute "padding" of operator Conv is not "valid."' @@ -558,9 +530,9 @@ def convert_matmul(g, op, block): # This implemention almost keeps same with ONNX # Need to check input shape as batch matmul must be supported. - a_shape = shape_of(inputs[0]) + a_shape = shape_of(inputs[0], dtype="int32") a_rank = infer_shape(a_shape)[0] - b_shape = shape_of(inputs[1]) + b_shape = shape_of(inputs[1], dtype="int32") b_rank = infer_shape(b_shape)[0] # When performing a batch matmul, we need to properly handle N-dim shapes. if a_rank > 2 or b_rank > 2: @@ -647,8 +619,8 @@ def convert_mul(g, op, block): y = g.get_node(op.input("Y")[0]) x_num_col_dims = op.attr("x_num_col_dims") y_num_col_dims = op.attr("y_num_col_dims") - x_shape = shape_of(x) - y_shape = shape_of(y) + x_shape = shape_of(x, dtype="int32") + y_shape = shape_of(y, dtype="int32") x_dim = infer_shape(x_shape)[0] y_dim = infer_shape(y_shape)[0] if x_num_col_dims < 0: @@ -685,6 +657,39 @@ def convert_mul(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_padding(g, op, block): + """Operator converter for padding.""" + + input_x = g.get_node(op.input("X")[0]) + input_padding = op.input("Paddings") + if input_padding: + padding = g.get_node(input_padding[0]) + padding = infer_value(padding, g.get_params()).numpy().tolist() + else: + padding = op.attr("paddings") + padding = op.attr("paddings") + value = op.attr("value") + data_format = op.attr("data_format") + mode = op.attr("mode") + assert mode != "circular", "Don't support mod='circular' for PaddlePaddle's padding" + if mode == "replicate": + mode = "edge" + + pad_len = len(padding) + new_paddings = [0] * (pad_len + 4) + for i in range(0, pad_len, 2): + index = -1 - i + if data_format[:2] != "NC": + index = -3 - i + new_paddings[index] = padding[i + 1] + new_paddings[index - 1] = padding[i] + + new_paddings = [new_paddings[i : i + 2] for i in range(0, len(new_paddings), 2)] + + out = _op.nn.pad(input_x, new_paddings, pad_value=value, pad_mode=mode) + g.add_node(op.output("Out")[0], out) + + def convert_pool2d(g, op, block): """Operator converter for pool2d.""" @@ -695,17 +700,19 @@ def convert_pool2d(g, op, block): paddings = op.attr("paddings") padding_algorithm = op.attr("padding_algorithm") pooling_type = op.attr("pooling_type") + if global_pooling: adaptive = True ksize = [1, 1] input_x = g.get_node(op.input("X")[0]) - in_h, in_w = infer_shape(input_x)[2:] + _, _, in_h, in_w = infer_shape(input_x) op_map = { "avg": "avg_pool2d", "max": "max_pool2d", } + strides = op.attr("strides") if isinstance(strides, int): strides = [strides, strides] @@ -717,22 +724,40 @@ def convert_pool2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, ksize[0], strides[0]) - pad_w = _get_pad_size(in_w, ksize[1], strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + input_x = autopad(input_x, strides, ksize) + paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] - if len(paddings) == 4: + elif len(paddings) == 4: paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] else: msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + # handle with special case + # while kernel size less than input size + # shrink kernel size to input size + if not isinstance(in_h, _op.Expr) and in_h < ksize[0]: + ksize[0] = in_h + if not isinstance(in_w, _op.Expr) and in_w < ksize[1]: + ksize[1] = in_w + if not adaptive: - out = getattr(_op.nn, op_map[pooling_type])( - input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode - ) + if pooling_type == "avg": + exclusive = op.attr("exclusive") + out = _op.nn.avg_pool2d( + input_x, + pool_size=ksize, + strides=strides, + padding=paddings, + ceil_mode=ceil_mode, + count_include_pad=not exclusive, + ) + else: + out = getattr(_op.nn, op_map[pooling_type])( + input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode + ) else: out = getattr(_op.nn, "adaptive_" + op_map[pooling_type])(input_x, output_size=ksize) g.add_node(op.output("Out")[0], out) @@ -795,7 +820,7 @@ def convert_shape(g, op, block): """Operator converter for shape.""" x = g.get_node(op.input("Input")[0]) - out = shape_of(x) + out = shape_of(x, dtype="int32") g.add_node(op.output("Out")[0], out) @@ -853,6 +878,17 @@ def convert_softmax(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_squeeze(g, op, block): + """Operator converter for squeeze2.""" + + x = g.get_node(op.input("X")[0]) + axes = op.attr("axes") + if not axes: + axes = None + x = _op.squeeze(x, axis=axes) + g.add_node(op.output("Out")[0], x) + + def convert_unsqueeze(g, op, block): """Operator converter for unsqueeze.""" @@ -903,6 +939,7 @@ def convert_unsqueeze(g, op, block): "matmul": convert_matmul, "matmul_v2": convert_matmul, "mul": convert_mul, + "pad3d": convert_padding, "pool2d": convert_pool2d, "relu": convert_unary_op, "reshape2": convert_reshape, @@ -910,6 +947,7 @@ def convert_unsqueeze(g, op, block): "shape": convert_shape, "slice": convert_slice, "softmax": convert_softmax, + "squeeze2": convert_squeeze, "tanh": convert_unary_op, "unsqueeze2": convert_unsqueeze, } @@ -954,10 +992,12 @@ def extract_parameters(self, program, scope=None): if not var.persistable: continue if isinstance(scope, dict): - self.params[name] = scope[name] + self.params[name] = _nd.array(scope[name]) else: - self.params[name] = np.array(scope.var(name).get_tensor()) - self.nodes[name] = _expr.const(self.params[name]) + self.params[name] = _nd.array(np.array(scope.var(name).get_tensor())) + shape = self.params[name].shape + dtype = self.params[name].dtype + self.nodes[name] = new_var(name, shape=shape, dtype=dtype) def check_input_shape(self, op, block): """Check the shape information of model's inputs, fixed shape is recommended.""" @@ -1048,14 +1088,37 @@ def from_translated_layer(self, layer, shape_dict): free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) + # remove unused parameters + final_params = dict() + for var in free_vars: + if var.name_hint in self.params: + final_params[var.name_hint] = self.params[var.name_hint] + self.params = final_params return mod, self.params def from_paddle(program_or_layer, shape_dict=None, scope=None): """Convert a PaddlePaddle model into an equivalent Relay Function. - PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, and PaddlePaddle scope stores all the weights of PaddlePaddle model. + + Parameters + ---------- + program_or_layer : object of `paddle.static.Program` or `paddle.jit.TranslatedLayer` + Loaded model by `paddle.static.load_inference_model` or `paddle.jit.load` + + shape_dict : dict of str to tuple/list, optional + The input shape of model + + scope : object of `paddle.static.Scope`, optional + The scope that saves all the weights of model, use `paddle.static.global_scope` by default + + Returns + ------- + mod : tvm.IRModule + The relay module for compilation + + params : dict of str to tvm.nd.NDArray """ import paddle diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 76cd0455661b..3fc202a7cc91 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2774,6 +2774,26 @@ def all_any_common(self, op, inputs, input_types): inp = inputs[0] return op(inp, axis=dim, keepdims=keepdim) + def searchsorted_common(self, sorted_sequence, values, out_int32, right): + dtype = "int32" if out_int32 else "int64" + values_shape = _infer_shape(values) + + if len(values_shape) == 0: + values = _op.expand_dims(values, 0) + + out = _op.searchsorted(sorted_sequence, values, right=right, dtype=dtype) + + if len(values_shape) == 0: + return _op.squeeze(out) + + return out + + def searchsorted(self, inputs, input_types): + return self.searchsorted_common(*inputs) + + def bucketize(self, inputs, input_types): + return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3]) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2999,6 +3019,8 @@ def create_convert_map(self): "aten::lstm": self.lstm, "aten::all": functools.partial(self.all_any_common, _op.all), "aten::any": functools.partial(self.all_any_common, _op.any), + "aten::searchsorted": self.searchsorted, + "aten::bucketize": self.bucketize, } def update_convert_map(self, custom_map): diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a66fc4736a98..3688ff5ff4e5 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -66,6 +66,7 @@ def __init__(self, model, subgraph, exp_tab): self.activation_fn_type = build_str_map(ActivationFunctionType()) self.builtin_options = build_str_map(BuiltinOptions()) self.prefetched_nodes = {} + self.allow_custom_ops = False # Add more operators self.convert_map = { @@ -287,6 +288,10 @@ def get_op_code_str(self, op): if op_code_id == BuiltinOperator.CUSTOM: # Custom operator custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode() + + if self.allow_custom_ops: + return "CUSTOM" + if custom_op_code_str == b"TFLite_Detection_PostProcess": return "DETECTION_POSTPROCESS" @@ -3695,7 +3700,7 @@ def _input_type(model): return shape_dict, dtype_dict -def from_tflite(model, shape_dict=None, dtype_dict=None): +def from_tflite(model, shape_dict=None, dtype_dict=None, op_converter=OperatorConverter): """Convert from tflite model into compatible relay Function. Parameters @@ -3755,7 +3760,7 @@ def from_tflite(model, shape_dict=None, dtype_dict=None): exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) # op code in model - op_converter = OperatorConverter(model, subgraph, exp_tab) + op_converter = op_converter(model, subgraph, exp_tab) op_converter.check_unsupported_ops() op_converter.convert_op_to_relay() diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 817f96b696df..dd1a65288955 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -41,6 +41,10 @@ register_strategy("topk", strategy.topk_strategy) register_pattern("topk", OpPattern.OPAQUE) +# searchsorted +register_strategy("searchsorted", strategy.searchsorted_strategy) +register_pattern("searchsorted", OpPattern.OPAQUE) + @script def _topk_shape_func_input_shape(data_shape, k, axis): @@ -80,3 +84,28 @@ def topk_shape_func(attrs, inputs, _): ret = [indices_out] return ret + + +@script +def _searchsorted_shape(sorted_sequence_shape, values_shape): + out_shape = output_tensor((values_shape.shape[0],), "int64") + if sorted_sequence_shape.shape[0] > 1: + assert ( + sorted_sequence_shape.shape[0] == values_shape.shape[0] + ), "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is not 1-D." + for i in range(values_shape.shape[0]): + if sorted_sequence_shape.shape[0] > 1 and i < values_shape.shape[0] - 1: + assert ( + sorted_sequence_shape[i] == values_shape[i] + ), "`sorted_sequence and `values` do not have the same shape along outer axes." + + out_shape[i] = values_shape[i] + return out_shape + + +@_reg.register_shape_func("searchsorted", False) +def searchsorted_shape_func(attrs, inputs, _): + """ + Shape func for searchsorted operator. + """ + return [_searchsorted_shape(inputs[0], inputs[1])] diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 18ce93322f43..daec488bbb94 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -89,9 +89,6 @@ register_broadcast_schedule("fast_exp") register_broadcast_schedule("fast_tanh") register_broadcast_schedule("fast_erf") -# a fake on_device schedule. -# this will not be used in actual computation -register_injective_schedule("on_device") # zeros diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 0284d2483ce5..76c806905b18 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1174,3 +1174,23 @@ def gather_nd_shape_func(attrs, inputs, _): assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd" return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))] + + +@script +def _gather_shape(data_shape, indices_shape, axis): + out_shape = output_tensor((data_shape.shape[0],), "int64") + for i in range(data_shape.shape[0]): + if i != axis: + assert ( + data_shape[i] == indices_shape[i] + ), "data and indices size at non-gather axes must be the same" + out_shape[i] = indices_shape[i] + return out_shape + + +@_reg.register_shape_func("gather", False) +def gather_shape_func(attrs, inputs, _): + """ + Shape func for gather operator. + """ + return [_gather_shape(inputs[0], inputs[1], attrs.axis)] diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 119936f632f8..809a9061ade0 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -115,3 +115,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): if ret_type == "both": return TupleWrapper(out, 2) return out + + +def searchsorted(sorted_sequence, values, right=False, dtype="int32"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : relay.Expr + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : relay.Expr + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : relay.Expr + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ + return _make.searchsorted(sorted_sequence, values, right, dtype) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 85ddfd9a7ec8..ca417942840d 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -192,11 +192,11 @@ def __init__(self, func_body: tvm.relay.Function): bias_add = requantize_op.args[0] qnn_conv2d = bias_add.args[0] data_layout = qnn_conv2d.attrs.data_layout - kernel_layout = qnn_conv2d.attrs.kernel_layout + self.kernel_layout = qnn_conv2d.attrs.kernel_layout # We consider the weights & biases as params as it should be a Constant self.weights = TensorParams( qnn_conv2d.args[QConv2DArgs.WEIGHTS.value], - kernel_layout, + self.kernel_layout, qnn_conv2d.args[QConv2DArgs.WEIGHTS_SCALE.value], qnn_conv2d.args[QConv2DArgs.WEIGHTS_ZERO_POINT.value], ) @@ -219,16 +219,18 @@ def __init__(self, func_body: tvm.relay.Function): requantize_op.args[RequantArgs.OFM_SCALE.value], requantize_op.args[RequantArgs.OFM_ZERO_POINT.value], ) - self.padding = qnn_conv2d.attrs.padding - self.strides = qnn_conv2d.attrs.strides - self.dilation = qnn_conv2d.attrs.dilation + attrs = qnn_conv2d.attrs + self.padding = attrs.padding + self.strides = attrs.strides + self.dilation = attrs.dilation self.activation = activation + self.channels = attrs.channels # If groups are equal to channel, its a depthwise_conv2d - self.groups = qnn_conv2d.attrs.groups + self.groups = attrs.groups self.is_depthwise = False channels_axis = {"HWIO": 3, "HWOI": 2} - if qnn_conv2d.attrs.groups == self.weights.shape[channels_axis[kernel_layout]]: + if self.groups == self.weights.shape[channels_axis[self.kernel_layout]]: self.is_depthwise = True def is_valid(self) -> bool: @@ -253,10 +255,52 @@ def is_valid(self) -> bool: legal_groups = [1, self.ofm.shape[3]] if self.groups not in legal_groups: return False - # This should be a valid QnnDepthwise2DParams, not QnnConv2DParams + # This should be a valid QnnDepthwiseConv2DParams, not QnnConv2DParams return not self.is_depthwise +class QnnDepthwiseConv2DParams(QnnConv2DParams): + """ + This class will parse a call to a ethosu.depthwise_conv2d composite function + and extract the parameter information. + """ + + composite_name = "ethosu.depthwise_conv2d" + # The hardware only supports padding upto the numbers as follows + padding_bounds = [31, 31, 32, 32] + + def __init__(self, func_body: tvm.relay.expr.Call): + QnnConv2DParams.__init__(self, func_body) + + def is_valid(self): + """ + Checks whether QnnDepthwiseConv2D + activation function has compatible attributes with HW + """ + tensor_params = [self.weights, self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params): + return False + if not check_weights(self.weights, self.dilation): + return False + if not check_bias(self.biases): + return False + if not check_strides(self.strides): + return False + if not check_batch_size(self.ifm): + return False + if not check_dilation(self.dilation): + return False + if not check_padding(self.padding, self.padding_bounds): + return False + if self.weights.layout != "HWOI": + return False + # only depth multiplier of size 1 is supported + if self.weights.shape[3] != 1: + return False + if not self.is_depthwise: + return False + return True + + def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ This function creates the pattern for qnn.conv2D with optional fused RELU activation. @@ -266,7 +310,22 @@ def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: ).has_attr({"kernel_layout": "HWIO"}) bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) req = is_op("qnn.requantize")( - qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant() + bias_add, is_constant(), is_constant(), is_constant(), is_constant() + ) + clip_or_req = req.optional(is_op("clip")) + return clip_or_req + + +def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for depthwise qnn.conv2D with optional fused RELU activation. + """ + qnn_conv2d = is_op("qnn.conv2d")( + wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + ).has_attr({"kernel_layout": "HWOI"}) + bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) + req = is_op("qnn.requantize")( + bias_add, is_constant(), is_constant(), is_constant(), is_constant() ) clip_or_req = req.optional(is_op("clip")) return clip_or_req @@ -279,7 +338,12 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal QnnConv2DParams.composite_name, qnn_conv2d_pattern(), lambda pat: QnnConv2DParams(pat).is_valid(), - ) + ), + ( + QnnDepthwiseConv2DParams.composite_name, + qnn_depthwise_conv2d_pattern(), + lambda pat: QnnDepthwiseConv2DParams(pat).is_valid(), + ), ] diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index cec7c4d141cb..03bb273c8f92 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -142,6 +142,7 @@ def partition_for_tensorrt( transform.RemoveUnusedFunctions(), transform.ConvertLayout( { + "nn.conv1d": ["NCW", "default"], "nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"], "nn.conv2d_transpose": ["NCHW", "default"], @@ -374,6 +375,23 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable return True +@_register_external_dynamic_check_func("nn.conv1d") +def conv1d_annotate_fn(expr): # pylint: disable=unused-variable + """Check if nn.conv1d is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.data_layout != "NCW": + logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIW": + logger.info("nn.conv1d: kernel_layout is %s but must be OIW.", attrs.kernel_layout) + return False + return True + + @_register_external_dynamic_check_func("nn.conv2d") def conv2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv2d is supported by TensorRT.""" @@ -912,6 +930,7 @@ def __init__(self): def visit_call(self, call): compute_intensive_ops = set( [ + "nn.conv1d", "nn.conv2d", "nn.conv2d_transpose", "nn.conv3d", diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index f06ee09fc7f4..17f75a07af64 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -18,7 +18,7 @@ """Backend compiler related feature registration""" from __future__ import absolute_import -from tvm import topi +from tvm import topi, relay from tvm.topi.utils import get_const_tuple from tvm.runtime import convert @@ -267,9 +267,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): result : tvm.relay.Expr The transformed expr """ - # pylint: disable=import-outside-toplevel - from tvm import relay - data, weight = inputs # First check if there is a LayoutConfig scope, and if so, whether @@ -363,9 +360,6 @@ def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts): result : tvm.relay.Expr The transformed expr """ - # pylint: disable=import-outside-toplevel - from tvm import relay - data, weight = inputs new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" @@ -446,9 +440,6 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layouts): result : tvm.relay.Expr The transformed expr """ - # pylint: disable=import-outside-toplevel - from tvm import relay - data, weight = inputs new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs" @@ -515,6 +506,30 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_convert_op_layout("nn.max_pool2d") +def convert_max_pool2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for max_pool2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current pooling + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of one layout string + layout string defining our desired layout for input and output. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + new_attrs = dict(attrs) + new_attrs["layout"] = str(desired_layouts[0]) + new_attrs["out_layout"] = str(desired_layouts[0]) + return relay.nn.max_pool2d(*inputs, **new_attrs) + + # max_pool3d reg.register_schedule("nn.max_pool3d", strategy.schedule_pool) reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -530,6 +545,30 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_convert_op_layout("nn.avg_pool2d") +def convert_avg_pool2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for avg_pool2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current pooling + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of one layout string + layout string defining our desired layout for input and output. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + new_attrs = dict(attrs) + new_attrs["layout"] = str(desired_layouts[0]) + new_attrs["out_layout"] = str(desired_layouts[0]) + return relay.nn.avg_pool2d(*inputs, **new_attrs) + + # avg_pool3d reg.register_schedule("nn.avg_pool3d", strategy.schedule_pool) reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -560,11 +599,59 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_convert_op_layout("nn.global_max_pool2d") +def convert_global_max_pool2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for global_max_pool2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current pooling + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of one layout string + layout string defining our desired layout for input and output. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + new_attrs = dict(attrs) + new_attrs["layout"] = str(desired_layouts[0]) + new_attrs["out_layout"] = str(desired_layouts[0]) + return relay.nn.global_max_pool2d(*inputs, **new_attrs) + + # global_avg_pool2d reg.register_schedule("nn.global_avg_pool2d", strategy.schedule_adaptive_pool) reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_convert_op_layout("nn.global_avg_pool2d") +def convert_global_avg_pool2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for global_avg_pool2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current pooling + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of one layout string + layout string defining our desired layout for input and output. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + new_attrs = dict(attrs) + new_attrs["layout"] = str(desired_layouts[0]) + new_attrs["out_layout"] = str(desired_layouts[0]) + return relay.nn.global_avg_pool2d(*inputs, **new_attrs) + + # adaptive_max_pool2d reg.register_schedule("nn.adaptive_max_pool2d", strategy.schedule_adaptive_pool) reg.register_pattern("nn.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -796,9 +883,6 @@ def convert_deformable_conv2d(attrs, inputs, tinfos, desired_layouts): result : tvm.relay.Expr The transformed expr """ - # pylint: disable=import-outside-toplevel - from tvm import relay - data, offset, weight = inputs new_attrs = dict(attrs) for attr in new_attrs: diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 5a17db745b3e..1821ff17258a 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -748,7 +748,14 @@ def log_softmax(data, axis=-1): def max_pool1d( - data, pool_size=(1,), strides=(1,), dilation=(1,), padding=(0,), layout="NCW", ceil_mode=False + data, + pool_size=(1,), + strides=(1,), + dilation=(1,), + padding=(0,), + layout="NCW", + out_layout="", + ceil_mode=False, ): r"""1D maximum pooling operator. @@ -783,6 +790,9 @@ def max_pool1d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -798,7 +808,9 @@ def max_pool1d( if isinstance(dilation, int): dilation = (dilation,) padding = get_pad_tuple1d(padding) - return _make.max_pool1d(data, pool_size, strides, dilation, padding, layout, ceil_mode) + return _make.max_pool1d( + data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode + ) def max_pool2d( @@ -808,6 +820,7 @@ def max_pool2d( dilation=(1, 1), padding=(0, 0), layout="NCHW", + out_layout="", ceil_mode=False, ): r"""2D maximum pooling operator. @@ -851,6 +864,9 @@ def max_pool2d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -866,7 +882,9 @@ def max_pool2d( if isinstance(dilation, int): dilation = (dilation, dilation) padding = get_pad_tuple2d(padding) - return _make.max_pool2d(data, pool_size, strides, dilation, padding, layout, ceil_mode) + return _make.max_pool2d( + data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode + ) def max_pool3d( @@ -876,6 +894,7 @@ def max_pool3d( dilation=(1, 1, 1), padding=(0, 0, 0), layout="NCDHW", + out_layout="", ceil_mode=False, ): r"""3D maximum pooling operator. @@ -912,6 +931,9 @@ def max_pool3d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -927,7 +949,9 @@ def max_pool3d( if isinstance(dilation, int): dilation = (dilation, dilation, dilation) padding = get_pad_tuple3d(padding) - return _make.max_pool3d(data, pool_size, strides, dilation, padding, layout, ceil_mode) + return _make.max_pool3d( + data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode + ) def avg_pool1d( @@ -937,6 +961,7 @@ def avg_pool1d( dilation=(1,), padding=(0,), layout="NCW", + out_layout="", ceil_mode=False, count_include_pad=False, ): @@ -973,6 +998,9 @@ def avg_pool1d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -992,7 +1020,15 @@ def avg_pool1d( dilation = (dilation,) padding = get_pad_tuple1d(padding) return _make.avg_pool1d( - data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + data, + pool_size, + strides, + dilation, + padding, + layout, + out_layout, + ceil_mode, + count_include_pad, ) @@ -1003,6 +1039,7 @@ def avg_pool2d( dilation=(1, 1), padding=(0, 0), layout="NCHW", + out_layout="", ceil_mode=False, count_include_pad=False, ): @@ -1048,6 +1085,9 @@ def avg_pool2d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -1067,7 +1107,15 @@ def avg_pool2d( dilation = (dilation, dilation) padding = get_pad_tuple2d(padding) return _make.avg_pool2d( - data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + data, + pool_size, + strides, + dilation, + padding, + layout, + out_layout, + ceil_mode, + count_include_pad, ) @@ -1078,6 +1126,7 @@ def avg_pool3d( dilation=(1, 1, 1), padding=(0, 0, 0), layout="NCDHW", + out_layout="", ceil_mode=False, count_include_pad=False, ): @@ -1115,6 +1164,9 @@ def avg_pool3d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -1134,7 +1186,15 @@ def avg_pool3d( dilation = (dilation, dilation, dilation) padding = get_pad_tuple3d(padding) return _make.avg_pool3d( - data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + data, + pool_size, + strides, + dilation, + padding, + layout, + out_layout, + ceil_mode, + count_include_pad, ) @@ -1145,6 +1205,7 @@ def max_pool2d_grad( strides=(1, 1), padding=(0, 0), layout="NCHW", + out_layout="", ceil_mode=False, ): r"""Gradient of 2D maximum pooling operator. @@ -1171,6 +1232,9 @@ def max_pool2d_grad( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -1179,7 +1243,9 @@ def max_pool2d_grad( result : tvm.relay.Expr The computed result. """ - return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding, layout, ceil_mode) + return _make.max_pool2d_grad( + out_grad, data, pool_size, strides, padding, layout, out_layout, ceil_mode + ) def avg_pool2d_grad( @@ -1189,6 +1255,7 @@ def avg_pool2d_grad( strides=(1, 1), padding=(0, 0), layout="NCHW", + out_layout="", ceil_mode=False, count_include_pad=False, ): @@ -1216,6 +1283,9 @@ def avg_pool2d_grad( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -1228,11 +1298,19 @@ def avg_pool2d_grad( The computed result. """ return _make.avg_pool2d_grad( - out_grad, data, pool_size, strides, padding, layout, ceil_mode, count_include_pad + out_grad, + data, + pool_size, + strides, + padding, + layout, + out_layout, + ceil_mode, + count_include_pad, ) -def global_max_pool2d(data, layout="NCHW"): +def global_max_pool2d(data, layout="NCHW", out_layout=""): r"""2D global maximum pooling operator. This operator takes data as input and does 2D max value calculation @@ -1258,15 +1336,18 @@ def global_max_pool2d(data, layout="NCHW"): layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.global_max_pool2d(data, layout) + return _make.global_max_pool2d(data, layout, out_layout) -def global_avg_pool2d(data, layout="NCHW"): +def global_avg_pool2d(data, layout="NCHW", out_layout=""): r"""2D global average pooling operator. This operator takes data as input and does 2D average value calculation @@ -1292,12 +1373,15 @@ def global_avg_pool2d(data, layout="NCHW"): layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.global_avg_pool2d(data, layout) + return _make.global_avg_pool2d(data, layout, out_layout) def upsampling( @@ -3114,7 +3198,7 @@ def space_to_depth(data, block_size, layout="NCHW"): return _make.space_to_depth(data, block_size, layout) -def adaptive_max_pool1d(data, output_size=None, layout="NCW"): +def adaptive_max_pool1d(data, output_size=None, layout="NCW", out_layout=""): r"""1D adaptive max pooling operator. This operator is experimental. This operator takes data as input and does 1D max value calculation @@ -3147,6 +3231,9 @@ def adaptive_max_pool1d(data, output_size=None, layout="NCW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr @@ -3155,10 +3242,10 @@ def adaptive_max_pool1d(data, output_size=None, layout="NCW"): output_size = [] or output_size if isinstance(output_size, int): output_size = [output_size] - return _make.adaptive_max_pool1d(data, output_size, layout) + return _make.adaptive_max_pool1d(data, output_size, layout, out_layout) -def adaptive_avg_pool1d(data, output_size=None, layout="NCW"): +def adaptive_avg_pool1d(data, output_size=None, layout="NCW", out_layout=""): r"""1D adaptive average pooling operator. This operator is experimental. This operator takes data as input and does 1D average value calculation @@ -3191,6 +3278,9 @@ def adaptive_avg_pool1d(data, output_size=None, layout="NCW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr @@ -3199,10 +3289,10 @@ def adaptive_avg_pool1d(data, output_size=None, layout="NCW"): output_size = [] or output_size if isinstance(output_size, int): output_size = [output_size] - return _make.adaptive_avg_pool1d(data, output_size, layout) + return _make.adaptive_avg_pool1d(data, output_size, layout, out_layout) -def adaptive_max_pool2d(data, output_size=None, layout="NCHW"): +def adaptive_max_pool2d(data, output_size=None, layout="NCHW", out_layout=""): r"""2D adaptive max pooling operator. This operator is experimental. This operator takes data as input and does 2D max value calculation @@ -3238,16 +3328,19 @@ def adaptive_max_pool2d(data, output_size=None, layout="NCHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [] or output_size - return _make.adaptive_max_pool2d(data, output_size, layout) + return _make.adaptive_max_pool2d(data, output_size, layout, out_layout) -def adaptive_avg_pool2d(data, output_size=None, layout="NCHW"): +def adaptive_avg_pool2d(data, output_size=None, layout="NCHW", out_layout=""): r"""2D adaptive average pooling operator. This operator is experimental. This operator takes data as input and does 2D average value calculation @@ -3283,16 +3376,19 @@ def adaptive_avg_pool2d(data, output_size=None, layout="NCHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [] or output_size - return _make.adaptive_avg_pool2d(data, output_size, layout) + return _make.adaptive_avg_pool2d(data, output_size, layout, out_layout) -def adaptive_max_pool3d(data, output_size=None, layout="NCDHW"): +def adaptive_max_pool3d(data, output_size=None, layout="NCDHW", out_layout=""): r"""3D adaptive max pooling operator. This operator is experimental. This operator takes data as input and does 3D max value calculation @@ -3327,16 +3423,19 @@ def adaptive_max_pool3d(data, output_size=None, layout="NCDHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [] or output_size - return _make.adaptive_max_pool3d(data, output_size, layout) + return _make.adaptive_max_pool3d(data, output_size, layout, out_layout) -def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"): +def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW", out_layout=""): r"""3D adaptive avg pooling operator. This operator is experimental. This operator takes data as input and does 3D avg value calculation @@ -3371,16 +3470,19 @@ def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [] or output_size - return _make.adaptive_avg_pool3d(data, output_size, layout) + return _make.adaptive_avg_pool3d(data, output_size, layout, out_layout) -def global_max_pool1d(data, layout="NCW"): +def global_max_pool1d(data, layout="NCW", out_layout=""): r"""1D global maximum pooling operator. This operator takes data as input and does 1D max value calculation @@ -3403,16 +3505,19 @@ def global_max_pool1d(data, layout="NCW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [1] - return _make.adaptive_max_pool1d(data, output_size, layout) + return _make.adaptive_max_pool1d(data, output_size, layout, out_layout) -def global_avg_pool1d(data, layout="NCW"): +def global_avg_pool1d(data, layout="NCW", out_layout=""): r"""1D global average pooling operator. This operator takes data as input and does 1D average value calculation @@ -3436,16 +3541,19 @@ def global_avg_pool1d(data, layout="NCW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [1] - return _make.adaptive_avg_pool1d(data, output_size, layout) + return _make.adaptive_avg_pool1d(data, output_size, layout, out_layout) -def global_max_pool3d(data, layout="NCDHW"): +def global_max_pool3d(data, layout="NCDHW", out_layout=""): r"""3D global maximum pooling operator. This operator takes data as input and does 3D max value calculation @@ -3469,16 +3577,19 @@ def global_max_pool3d(data, layout="NCDHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [1, 1, 1] - return _make.adaptive_max_pool3d(data, output_size, layout) + return _make.adaptive_max_pool3d(data, output_size, layout, out_layout) -def global_avg_pool3d(data, layout="NCDHW"): +def global_avg_pool3d(data, layout="NCDHW", out_layout=""): r"""3D global average pooling operator. This operator takes data as input and does 3D average value calculation @@ -3503,13 +3614,16 @@ def global_avg_pool3d(data, layout="NCDHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [1, 1, 1] - return _make.adaptive_avg_pool3d(data, output_size, layout) + return _make.adaptive_avg_pool3d(data, output_size, layout, out_layout) def correlation( diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 8fd46817b817..dba40b2f6f34 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -564,6 +564,11 @@ class TopkAttrs(Attrs): """Attributes used in topk operators""" +@tvm._ffi.register_object("relay.attrs.SearchSortedAttrs") +class SearchSortedAttrs(Attrs): + """Attributes used in searchsorted operators""" + + @tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs") class TupleGetItemAttrs(Attrs): """Attributes used in tuple item access operators""" diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index e8731a0d6954..06dfc87038fe 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -130,9 +130,9 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): elif layout == "NHWC": if "SMLAD" in isa and kernel_layout == "HWOI": strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.conv2d_direct_simd), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd), - name="conv2d_direct_simd.micro_dev", + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_direct_simd), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_direct_simd), + name="conv2d_nhwc_direct_simd.micro_dev", ) elif kernel_layout == "HWIO": is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm() diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index da7cbd5cec10..5f24dbda9d35 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1022,6 +1022,18 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): return strategy +@searchsorted_strategy.register(["cuda", "gpu"]) +def searchsorted_strategy_cuda(attrs, inputs, out_type, target): + """searchsorted cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_searchsorted(topi.cuda.searchsorted), + wrap_topi_schedule(topi.cuda.schedule_extern), + name="searchsorted.cuda", + ) + return strategy + + @multibox_prior_strategy.register(["cuda", "gpu"]) def multibox_prior_strategy_cuda(attrs, inputs, out_type, target): """multibox_prior cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index d021b5d9d84d..777f17ba6084 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1002,6 +1002,31 @@ def topk_strategy(attrs, inputs, out_type, target): return strategy +# searchsorted +def wrap_compute_searchsorted(topi_compute): + """Wrap searchsorted compute""" + + def _compute_searchsorted(attrs, inputs, out_type): + right = attrs.right + dtype = attrs.dtype + return [topi_compute(inputs[0], inputs[1], right, dtype)] + + return _compute_searchsorted + + +# searchsorted_strategy +@override_native_generic_func("searchsorted_strategy") +def searchsorted_strategy(attrs, inputs, out_type, target): + """searchsorted generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_searchsorted(topi.searchsorted), + wrap_topi_schedule(topi.generic.schedule_extern), + name="searchsorted.generic", + ) + return strategy + + # multibox_prior def wrap_compute_multibox_prior(topi_compute): """Wrap multibox_prior compute""" diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index e47928919ce1..e615bbf21b86 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -1070,7 +1070,7 @@ def fixed_point_multiply(data, multiplier, shift): The input tensor. multiplier : int The integer multiplier of the fixed point constant. - a_max : float + shift : int The integer shift of the fixed point constant. Returns diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 3b4d97576cd7..7f4724db22b2 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -51,7 +51,7 @@ def kind2str(kind): def _forward_op(ref_call, args): """forward the operator of ref_call with provided arguments""" - return _expr.Call(ref_call.op, args, ref_call.attrs, ref_call.type_args) + return _expr.Call(ref_call.op, args, ref_call.attrs, ref_call.type_args, ref_call.span) @tvm._ffi.register_object("relay.quantize.QConfig") diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index b9d6806306f4..50f473aea1f2 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -24,7 +24,7 @@ import tvm from tvm import relay from tvm.relay.adt import Pattern -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler from tvm.relay.expr import Expr, GlobalVar, Var from tvm.relay.function import Function from tvm.relay.expr_functor import ExprFunctor @@ -61,7 +61,7 @@ def __init__(self, mod, target) -> None: super().__init__() self.mod = mod self.tgt = target - self.engine = compile_engine.get() + self.tec = te_compiler.get() self.fun_no = 0 self.var_no = 0 self.var_map = {} @@ -153,7 +153,10 @@ def parse_name(self, name: str): def parse_numpy_array(self, arr): """Given a Numpy array, produces an appropriate Python array or numerical literal representing its contents.""" - parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i) + + def parse_single(i): + return NameConstant(i) if isinstance(i, bool) else Num(i) + if arr.ndim == 0: return parse_single(arr.item()) if arr.ndim == 1: @@ -240,11 +243,11 @@ def create_op_call(self, op: Function, relay_args, py_args): the generated Python code.""" # compile the function and register globally - cc_key = compile_engine.CCacheKey(op, self.tgt) + cc_key = te_compiler.CCacheKey(op, self.tgt) func_hash = tvm.ir.structural_hash(op) op_name = "_lowered_op_{}".format(func_hash) if not tvm.get_global_func(op_name, allow_missing=True): - jitted = self.engine.jit(cc_key, self.tgt) + jitted = self.tec.jit(cc_key, self.tgt) tvm.register_func(op_name, jitted) def convert_input(py_input, arg_type): diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index e5ec73db51b9..c3b0056eb591 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -379,7 +379,7 @@ def _update_tracker(self, period_update=False): if need_update_info: keylist = "[" + ",".join(self._key_set) + "]" - cinfo = {"key": "server:proxy" + keylist} + cinfo = {"key": "server:proxy" + keylist, "addr": [None, self._listen_port]} base.sendjson(self._tracker_conn, [TrackerCode.UPDATE_INFO, cinfo]) assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS self._tracker_pending_puts = [] diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 74c1f7ac07aa..5a576a705e8a 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -337,9 +337,10 @@ def request(self, key, user, priority, callback): def close(self, conn): self._connections.remove(conn) if "key" in conn._info: - key = conn._info["key"].split(":")[1] # 'server:rasp3b' -> 'rasp3b' for value in conn.put_values: - self._scheduler_map[key].remove(value) + _, _, _, key = value + rpc_key = key.split(":")[0] + self._scheduler_map[rpc_key].remove(value) def stop(self): """Safely stop tracker.""" diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index b91fe727698b..7d40a81e498a 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -47,6 +47,35 @@ def csv(self): """ return _ffi_api.AsCSV(self) + def table(self, sort=True, aggregate=True, col_sums=True): + """Generate a human-readable table + + Parameters + ---------- + sort : bool + + If aggregate is true, whether to sort call frames by + descending duration. If aggregate is False, whether to + sort frames by order of appearancei n the program. + + aggregate : bool + + Whether to join multiple calls to the same op into a + single line. + + col_sums : bool + + Whether to include the sum of each column. + + Returns + ------- + table : str + + A human-readable table + + """ + return _ffi_api.AsTable(self, sort, aggregate, col_sums) + def json(self): """Convert this profiling report into JSON format. diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 75566cf6e2c5..56d080857a7d 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -22,8 +22,10 @@ import tvm from tvm.ir import Span +from tvm.ir.expr import Range from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion from tvm.runtime import Object +from tvm.tir.expr import IterVar from .tir.node import BufferSlice @@ -41,10 +43,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(a, (16, 16), "float32") for i, j, k in T.grid(16, 16, 16): - with T.block([16, 16, T.reduce_axis(16)], "matmul") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k} + with T.block("matmul"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k) # iter_bindings = {vj: i, vj: j, vk: k} T.where(True) # predicate of the block_realize @@ -72,8 +74,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: """List[Buffer]: list of T.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature""" - iter_bindings: Mapping[Var, PrimExpr] = {} - """Mapping[Var, PrimExpr]: map of block iter var to its values""" + iter_values: List[PrimExpr] = [] + """List[PrimExpr]: list of binding values for iter vars""" + iter_vars: List[IterVar] = [] + """List[PrimExpr]: list of iter vars in the block""" reads: Optional[List[BufferSlice]] = None """Optional[List[BufferSlice]]: list of T.reads statements in the block signature, None for not-visited""" @@ -91,7 +95,8 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: def __init__(self): self.alloc_buffers = [] self.match_buffers = [] - self.iter_bindings = {} + self.iter_values = [] + self.iter_vars = [] self.reads = None self.writes = None self.annotations = None @@ -112,8 +117,8 @@ class ContextMaintainer: """List[List[synr.ast.Node]]: The ast nodes insides the current scope""" block_info_stack: List[BlockInfo] = [] """List[BlockInfo]: The block info for the current block scope""" - loop_stack: List[List[Var]] = [] - """List[List[Var]]: List of loop vars inside the current block scope""" + loop_stack: Dict[Var, Range] = {} + """Dict[Var, Range]: The dict from loop var to its domain outside the block""" symbols: List[Dict[str, Union[Var, Buffer]]] = [] """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" @@ -137,7 +142,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No # scope context self.node_stack = [] self.block_info_stack = [] - self.loop_stack = [] + self.loop_stack = {} self.symbols = [] # function context self.func_params = [] @@ -183,8 +188,6 @@ def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None): The synr AST nodes in new scope """ self.enter_scope(nodes) - # Create a new loop stack for the new block - self.loop_stack.append([]) # Create a new BlockInfo for the new block self.block_info_stack.append(BlockInfo()) @@ -196,8 +199,6 @@ def exit_scope(self): def exit_block_scope(self): """Pop the inner most block scope, the function will call `exit_scope` implicitly""" self.exit_scope() - # Pop loop stack - self.loop_stack.pop() # Pop block_info self.block_info_stack.pop() diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index d5e79e8676c1..080aa0476bec 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -377,12 +377,13 @@ def A(): """ if len(node.assignments) == 1: if not ( - isinstance(node.assignments[0].lhs, ast.Var) - and node.assignments[0].lhs.id.name == "__tvm_meta__" + len(node.assignments[0].lhs) == 1 + and isinstance(node.assignments[0].lhs[0], ast.Var) + and node.assignments[0].lhs[0].id.name == "__tvm_meta__" ): self.report_error( "The only top level assignments allowed are `__tvm_meta__ = ...`", - node.assignments[0].lhs.span, + node.assignments[0].span, ) self.init_meta( MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context) @@ -489,6 +490,31 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: self.context.exit_scope() return func + def transform_Lambda(self, node): + """Lambda visitor + + Return an array of input parameters and the transformed lambda body. + """ + + self.context.enter_scope(nodes=[node.body]) + + # add parameters of the lambda + arg_vars = [] + for arg in node.params: + arg_var = tvm.te.var(arg.name) + arg_vars.append(arg_var) + self.context.update_symbol(arg.name, arg_var, node) + + # the body of a lambda must be an expr + if not isinstance(node.body, ast.Expr): + self.report_error("The body of a lambda must be an expression", node.span) + + # transform the body of the lambda + body = self.transform(node.body) + + self.context.exit_scope() + return arg_vars, body + def transform_Assign(self, node): """Assign visitor AST abstract grammar: @@ -526,18 +552,19 @@ def transform_Assign(self, node): return self.parse_body(node) else: value = self.transform(node.rhs) - if not isinstance(node.lhs, ast.Var): + if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): # This is a little confusing because it only is true when # we have taken this branch. We might need to clarify what # exectly is allowed in Assignments in tvmscript. self.report_error( "Left hand side of assignment must be an unqualified variable", - node.lhs.span, + node.span, ) + ast_var = node.lhs[0] var = tvm.te.var( - node.lhs.id.name, - self.parse_type(node.ty, node.lhs), - span=tvm_span_from_synr(node.lhs.span), + ast_var.id.name, + self.parse_type(node.ty, ast_var), + span=tvm_span_from_synr(ast_var.span), ) self.context.update_symbol(var.name, var, node) body = self.parse_body(node) @@ -596,7 +623,7 @@ def transform_For(self, node): For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) By now 1 pattern of For is supported: 1. for scope handler - for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/tir.range()/ + for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/ T.grid()/T.thread_binding() """ @@ -892,9 +919,20 @@ def transform_Attr(self, node): namespace. """ - if isinstance(node.object, ast.Var): - if self.match_tir_namespace(node.object.id.name): - func_name = "tir." + node.field.name + def get_full_attr_name(node: ast.Attr) -> str: + reverse_field_names = [node.field.name] + while isinstance(node.object, ast.Attr): + node = node.object + reverse_field_names.append(node.field.name) + if isinstance(node.object, ast.Var): + reverse_field_names.append(node.object.id.name) + return ".".join(reversed(reverse_field_names)) + + if isinstance(node.object, (ast.Var, ast.Attr)): + full_attr_name = get_full_attr_name(node) + attr_object, fields = full_attr_name.split(".", maxsplit=1) + if self.match_tir_namespace(attr_object): + func_name = "tir." + fields res = Registry.lookup(func_name) if res is not None: return res @@ -903,9 +941,7 @@ def transform_Attr(self, node): except TVMError as e: # Check if we got an attribute error if e.args[0].find("AttributeError"): - self.report_error( - f"Unregistered function `tir.{node.field.name}`.", node.field.span - ) + self.report_error(f"Unregistered function `tir.{fields}`.", node.span) else: raise e diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index 4d7fe80b28b1..2e800355bef6 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -16,6 +16,7 @@ # under the License. """TVM Script Parser Intrinsic Classes""" # pylint: disable=redefined-builtin, relative-beyond-top-level +import builtins from typing import List, Any import tvm.tir @@ -211,3 +212,20 @@ def store(var, index, value, predicate=True, span=None): return tvm.tir.Store(var, value, index, predicate, span) super().__init__(store, stmt=True) + + +@register +def comm_reducer(lambda_io, identities, span): + """Create a CommReducer from lambda inputs/outputs and the identities""" + lambda_input = lambda_io[0] + lambda_output = lambda_io[1] + + num_args = len(lambda_input) + num_arg_per_group = num_args // 2 + x = [lambda_input[i] for i in builtins.range(0, num_arg_per_group)] + y = [lambda_input[i] for i in builtins.range(num_arg_per_group, num_args)] + + if not isinstance(lambda_output, tuple): + lambda_output = (lambda_output,) + + return tvm.tir.CommReducer(x, y, lambda_output, identities, span) diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 487a71d4f077..4750ad7626e2 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -134,12 +134,14 @@ def enter_scope( if isinstance(node, synr.ast.With): vars = WithScopeHandler.get_optional_vars(node, context) if len(vars) != 1: - context.report_error("Unexpected number of vars", node.span) + context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) name = vars[0].id.name var_span = vars[0].id.span elif isinstance(node, synr.ast.Assign): - name = node.lhs.id.name - var_span = node.lhs.id.span + if len(node.lhs) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) + name = node.lhs[0].id.name + var_span = node.lhs[0].id.span else: raise Exception("Internal Bug") @@ -247,42 +249,16 @@ def let(var, value, span): @register class Block(WithScopeHandler): - """With scope handler T.block(extents, name) as iter_vars""" + """With scope handler T.block(name)""" def __init__(self): - def block(axes=None, name_hint: str = "", span: Optional[Span] = None): + def block(name_hint: str = "", span: Optional[Span] = None): assert ( self.node and self.context and self.body ), "call 'exit_scope' before 'enter_scope'" block_info = self.context.block_info_stack[-1] - if axes is None: - axes = [] - if len(axes) != len(self.block_vars): - self.context.report_error( - "Inconsistent number of block vars, " - + f"there are {len(axes)} axes but {len(self.block_vars)} block vars. " - + "The number of block vars should match the number of axes.", - self.node.span, - ) - block_iters: List[IterVar] = [] - for i, axis in enumerate(axes): - axis = tvm.runtime.convert(axis) - if isinstance(axis, tvm.tir.PrimExpr): - block_var_dom = Range.from_min_extent(0, axis) - block_iters.append(IterVar(block_var_dom, self.block_vars[i], 0)) - elif isinstance(axis, Range): - block_iters.append(IterVar(axis, self.block_vars[i], 0)) - elif isinstance(axis, IterVar): - block_iters.append(IterVar(axis.dom, self.block_vars[i], axis.iter_type)) - else: - self.context.report_error( - "Invalid argument of T.block(), " - + f"expected PrimExpr, Range or IterVar, but got {type(axis)}", - self.node.span, - ) # create block read/write regions - reads: List[BufferRegion] = ( [buffer_slice_to_region(read) for read in block_info.reads] if block_info.reads @@ -301,7 +277,7 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): if region_detect_mask != 0: annotations["tir.script_parsing_detect_access"] = region_detect_mask inner = tvm.tir.Block( - block_iters, + block_info.iter_vars, reads, writes, name_hint, @@ -312,35 +288,13 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): annotations, span, ) - # create block var iter binding - values: List[PrimExpr] - if not block_info.iter_bindings: - values = self.context.loop_stack[-2].copy() - if len(block_iters) == 0: - # It is an opaque block without any bindings - values = [] - elif len(values) == 0: - values = [tvm.tir.const(float("nan"), dtype="float32")] * len(block_iters) - elif len(values) != len(block_iters): - self.context.report_error( - "Number of block iter var and outer loop nesting mismatch, " - + f"{len(block_iters)} block iter vars but {len(values)} loops", - self.node.span, - ) - else: - for block_var in self.block_vars: - if block_var not in block_info.iter_bindings: - self.context.report_error( - "Missing block iter var binding for " + block_var.name, - self.node.span, - ) - values = [block_info.iter_bindings[block_var] for block_var in self.block_vars] + assert len(block_info.iter_vars) == len(block_info.iter_values) predicate = ( tvm.tir.const(True, "bool") if block_info.predicate is None else block_info.predicate ) - body = tvm.tir.BlockRealize(values, predicate, inner, span) + body = tvm.tir.BlockRealize(block_info.iter_values, predicate, inner, span) return body super().__init__(func=block, concise_scope=False, def_symbol=True) @@ -358,10 +312,13 @@ def enter_scope( node, synr.ast.With ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}" - vars = WithScopeHandler.get_optional_vars(node, context) - self.block_vars = [tvm.te.var(var.id.name) for var in vars] - for block_var in self.block_vars: - context.update_symbol(block_var.name, block_var, node) + optional_vars = [var.id.name for var in WithScopeHandler.get_optional_vars(node, context)] + if optional_vars: + context.report_error( + f"Block expected no optional_vars (e.g., `x` in `with block() as x`), " + f"but got {optional_vars}", + node.span, + ) @register @@ -378,12 +335,38 @@ def init(span: Span = None): super().__init__(func=init, concise_scope=False, def_symbol=True) +class LoopInfo: + """Helper class for loop information""" + + loop_var: Var + begin: PrimExpr + extent: PrimExpr + kind: ForKind + thread_binding: Optional[str] + annotations: Optional[Mapping[str, Object]] + + def __init__( + self, + begin: PrimExpr, + extent: PrimExpr, + kind: ForKind, + thread_binding: Optional[str] = None, + annotations: Optional[Mapping[str, Object]] = None, + ) -> None: + self.begin = begin + self.extent = extent + self.kind = kind + self.thread_binding = thread_binding + self.annotations = annotations + + class ForScopeHandler(ScopeHandler): """Base class for all for scope handlers""" def __init__(self, func): super().__init__(func) - self.loop_vars: Optional[List[Var]] = None + self.loop_vars: List[Var] = [] + self.loop_info: List[LoopInfo] = [] def enter_scope( self, @@ -415,12 +398,23 @@ def enter_scope( span, ) + self.node = node + self.context = context + # generate loop vars self.loop_vars = [ tvm.te.var(name, dtype="int32", span=span) for name, span in zip(loop_var_names, spans) ] - for loop_var in self.loop_vars: + # collect loop infos by calling self.func + call_with_error_reporting(context.report_error, span, self.func, *arg_list) + if len(self.loop_vars) != len(self.loop_info): + self.context.report_error( + f"Inconsistent number of vars and loops, got {len(self.loop_vars)} " + + f"vs {len(self.loop_info)}", + self.node.span, + ) + for loop_var, loop_info in zip(self.loop_vars, self.loop_info): context.update_symbol(loop_var.name, loop_var, node) - context.loop_stack[-1].append(loop_var) + context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent) def exit_scope( self, @@ -430,19 +424,34 @@ def exit_scope( span: synr.ast.Span, ): assert self.loop_vars, "call 'exit_scope' before 'enter_scope'" - for _ in self.loop_vars: - context.loop_stack[-1].pop() - return super().exit_scope(node, context, arg_list, span) + for loop_var in self.loop_vars: + context.loop_stack.pop(loop_var) + # Use assert here since we have check it in `enter_scope` + assert len(self.loop_vars) == len(self.loop_info) + + body = self.body + for var, info in zip(reversed(self.loop_vars), reversed(self.loop_info)): + body = tvm.tir.For( + var, + info.begin, + info.extent, + info.kind, + body, + info.thread_binding, + info.annotations, + span=tvm_span_from_synr(span), + ) - def create_loop( + return body + + def create_loop_info( self, begin: PrimExpr, end: PrimExpr, kind: ForKind, thread_binding: Optional[str] = None, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, - ) -> tvm.tir.For: + ) -> None: """ Helper function for creating For in TVM Script parser. @@ -471,30 +480,16 @@ def create_loop( for : For The constructed For. """ - assert ( - self.loop_vars and self.context and self.node - ), "call 'exit_scope' before 'enter_scope'" - if len(self.loop_vars) != 1: - self.context.report_error( - f"Expected exactly one loop var, but got {self.loop_vars}", self.node.span - ) + assert self.context and self.node, "call 'exit_scope' before 'enter_scope'" extent = end if begin == 0 else self.context.analyzer.simplify(end - begin) - annos: Mapping[str, Object] = {} + self.annotations: Mapping[str, Object] = {} if annotations is not None: - annos = { + self.annotations = { key: tvm.tir.StringImm(val) if isinstance(val, str) else val for key, val in annotations.items() } - return tvm.tir.For( - self.loop_vars[0], - begin, - extent, - kind, - self.body, - thread_binding=thread_binding, - annotations=annos, - span=span, - ) + + self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations)) @register @@ -506,9 +501,8 @@ def serial( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop(begin, end, ForKind.SERIAL, annotations=annotations, span=span) + self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) super().__init__(serial) @@ -522,11 +516,8 @@ def parallel( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.PARALLEL, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations) super().__init__(parallel) @@ -540,11 +531,8 @@ def vectorized( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.VECTORIZED, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations) super().__init__(vectorized) @@ -558,11 +546,8 @@ def unroll( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.UNROLLED, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations) super().__init__(unroll) @@ -577,16 +562,14 @@ def thread_binding( end: PrimExpr, thread: str, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread, span=span) - return self.create_loop( + thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread) + self.create_loop_info( begin, end, ForKind.THREAD_BINDING, thread_binding=thread_iter_var, annotations=annotations, - span=span, ) super().__init__(thread_binding) @@ -603,12 +586,11 @@ def for_range( begin: PrimExpr, end: PrimExpr = None, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): if end is None: end = begin begin = 0 - return self.create_loop(begin, end, ForKind.SERIAL, annotations=annotations, span=span) + self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) super().__init__(for_range) @@ -621,19 +603,8 @@ class Grid(ForScopeHandler): """For scope handler T.grid(extents)""" def __init__(self): - def grid(*extents: List[PrimExpr], span: Span): - assert ( - self.node and self.context and self.loop_vars - ), "call 'exit_scope' before 'enter_scope'" - if len(self.loop_vars) != len(extents): - self.context.report_error( - "Inconsistent number of loop vars and extents, " - + f"got {len(self.loop_vars)} vs {len(extents)}", - self.node.span, - ) - body = self.body - for loop_var, extent in zip(reversed(self.loop_vars), reversed(extents)): - body = tvm.tir.For(loop_var, 0, extent, ForKind.SERIAL, body, span=span) - return body + def grid(*extents: List[PrimExpr]): + for extent in extents: + self.create_loop_info(0, extent, ForKind.SERIAL) super().__init__(grid) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 69cf15f493de..de212352f3e4 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -21,17 +21,18 @@ import synr from synr import ast +from tvm.ir.expr import PrimExpr, Range import tvm.tir from tvm.runtime import Object from tvm import te from tvm.ir import Span -from tvm.tir import IntImm +from tvm.tir import IntImm, IterVar from .node import BufferSlice from .utils import buffer_slice_to_region -from ..context_maintainer import ContextMaintainer +from ..context_maintainer import BlockInfo, ContextMaintainer from ..registry import register from ..utils import ( get_param_list, @@ -132,9 +133,10 @@ def match_buffer( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "match_buffer must be assigned to a buffer, e.g. A = match_buffer(...)", + "`match_buffer` must be assigned to a single buffer, " + "e.g. A = match_buffer(...)", self.node.span, ) if strides is None: @@ -143,10 +145,11 @@ def match_buffer( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -173,7 +176,7 @@ def match_buffer( + str(type(param)), self.node.rhs.params[0].span, ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) super().__init__(match_buffer, def_symbol=True) @@ -201,9 +204,9 @@ def buffer_decl( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "buffer_decl must be assigned to a buffer, e.g. A = buffer_decl(...)", + "`buffer_decl` must be assigned to a single buffer, e.g. A = buffer_decl(...)", self.node.span, ) @@ -213,10 +216,11 @@ def buffer_decl( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -226,7 +230,7 @@ def buffer_decl( buffer_type, span=span, ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) return buffer super().__init__(buffer_decl, def_symbol=True) @@ -257,9 +261,10 @@ def alloc_buffer( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "alloc_buffer must be assigned to a buffer, e.g. A = alloc_buffer(...)", + "`alloc_buffer` must be assigned to a single buffer, " + "e.g. A = alloc_buffer(...)", self.node.span, ) @@ -269,10 +274,11 @@ def alloc_buffer( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -283,32 +289,11 @@ def alloc_buffer( span=span, ) self.context.current_block_scope().alloc_buffers.append(buffer) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) super().__init__(alloc_buffer, def_symbol=True) -@register -class BlockVarBind(SpecialStmt): - """Special function bind(block_iter, binding_value) - - Example - ------- - .. code-block:: python - - T.bind(vx, i) - """ - - def __init__(self): - def bind(iter_var, values, span=None): - block_scope = self.context.current_block_scope() - if iter_var in block_scope.iter_bindings: - self.context.report_error("Duplicate iter_var bindings of " + str(iter_var), span) - block_scope.iter_bindings[iter_var] = values - - super().__init__(bind, def_symbol=False) - - @register class BlockReads(SpecialStmt): """Special function reads([read_buffer_regions]) @@ -412,6 +397,315 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None): super().__init__(block_attr, def_symbol=False) +class BlockAxis(SpecialStmt): + """Special stmt for defining a spatial block axis + axis.S(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.S(128, i * 4 + j) + """ + + def axis( + self, + var_name: str, + dom: Union[PrimExpr, Range], + value: PrimExpr, + iter_type: int, + span: Optional[Span] = None, + ) -> None: + """ + Helper function for creating block axis + + Parameters + ---------- + var_name : str + The name_hint of var + + dom : Union[PrimExpr, Range] + The iter domain. + + value : PrimExpr + The binding value + + iter_type : int + The iteration type. + + span : Optional[Span] + The location of this for in the source code. + """ + assert self.context, "call 'exit_scope' before 'enter_scope'" + block_scope: BlockInfo = self.context.current_block_scope() + if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]: + self.context.report_error("Duplicate block axis " + var_name, self.node.span) + + block_var = tvm.tir.Var(var_name, dtype="int32") + dom = tvm.runtime.convert(dom) + if isinstance(dom, PrimExpr): + dom = tvm.ir.Range.from_min_extent(0, dom) + elif not isinstance(dom, tvm.ir.Range): + self.context.report_error( + f"Block axis domain expected PrimExpr or Range, but got {type(value)}", + self.node.span, + ) + value = tvm.runtime.convert(value) + if not isinstance(value, PrimExpr): + self.context.report_error( + f"Block axis value expected PrimExpr, but got {type(value)}", + self.node.span, + ) + iter_var = tvm.tir.IterVar(dom, block_var, iter_type) + block_scope.iter_vars.append(iter_var) + block_scope.iter_values.append(value) + self.context.update_symbol(var_name, block_var, self.node) + + +@register +class BlockAxisSpatial(BlockAxis): + """Special stmt for defining a spatial block axis + axis.spatial(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.spatial(128, k) + """ + + def __init__(self): + def axis_spatial( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.spatial` must be assigned to a var, e.g. vi = axis.spatial(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar) + + super().__init__(axis_spatial, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.spatial", get_param_list(self.func) + + +@register +class BlockAxisS(BlockAxis): + """The sugar special stmt for defining a spatial block axis + axis.S(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.S(128, k) + """ + + def __init__(self): + def axis_spatial( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.S` must be assigned to a var, e.g. vi = axis.S(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar) + + super().__init__(axis_spatial, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.S", get_param_list(self.func) + + +@register +class BlockAxisReduce(BlockAxis): + """Special stmt for defining a reduce block axis + axis.reduce(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.reduce(128, k) + """ + + def __init__(self): + def axis_reduce( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.reduce` must be assigned` to a var, e.g. vi = axis.reduce(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce) + + super().__init__(axis_reduce, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.reduce", get_param_list(self.func) + + +@register +class BlockAxisR(BlockAxis): + """The sugar special stmt for defining a reduce block axis + axis.R(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.R(128, k) + """ + + def __init__(self): + def axis_reduce( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.R` must be assigned to a var, e.g. vi = axis.R(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce) + + super().__init__(axis_reduce, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.R", get_param_list(self.func) + + +@register +class BlockAxisScan(BlockAxis): + """Special stmt for defining a ordered block axis + axis.scan(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.scan(128, k) + """ + + def __init__(self): + def axis_scan( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.scan` must be assigned to a var, e.g. vi = axis.scan(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.Ordered) + + super().__init__(axis_scan, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.scan", get_param_list(self.func) + + +@register +class BlockAxisOpaque(BlockAxis): + """Special stmt for defining a opaque block axis + axis.opaque(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.opaque(128, k) + """ + + def __init__(self): + def axis_opaque( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.opaque` must be assigned to a var, e.g. vi = axis.opaque(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DimInfo) + + super().__init__(axis_opaque, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.opaque", get_param_list(self.func) + + +@register +class BlockAxisRemap(BlockAxis): + """Special stmt for remapping loops vars to block axes. + axis.remap(iter_type, iter_value) + + Note + ---- + Iter_type is a string consisting of 'S' and 'R', where 'S' means + for spatial and 'R' means for reduce. + + Example + ------- + .. code-block:: python + + vi, vj = T.axis.remap("SS", [i, j]) + """ + + def __init__(self): + def axis_remap(iter_types: str, loop_vars: List[tvm.tir.expr.Var], span: Span = None): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) >= 1: + self.context.report_error( + "`axis.remap` must be assigned to one or more vars, " + "e.g. vi, vj = axis.remap(...)", + self.node.span, + ) + var_num: int = len(self.node.lhs) + if var_num != len(iter_types): + self.context.report_error( + f"`iter_type` expected {var_num} charactor(s), " + f"but got {len(iter_types)}: {iter_types}", + span, + ) + if var_num != len(loop_vars): + self.context.report_error( + f"`iter_type` expected {var_num} loop var(s), " + f"but got {len(loop_vars)}: {loop_vars}", + span, + ) + for var, iter_ty, loop_var in zip(self.node.lhs, iter_types, loop_vars): + iter_type: int + if iter_ty == "S": + iter_type = IterVar.DataPar + elif iter_ty == "R": + iter_type = IterVar.CommReduce + else: + self.context.report_error( + f'`iter_type` only expected "S" (for spatial) or "R" (for reduce), ' + f'but got "{iter_ty}"', + span, + ) + + if not isinstance(loop_var, tvm.tir.expr.Var): + self.context.report_error( + f"Values of `axis.remap` expected single loop var, but got {loop_var}", + loop_var.span, + ) + loops = self.context.loop_stack + if loop_var not in loops: + self.context.report_error( + f"Cannot find loop var {loop_var} in loop nesting.", + span, + ) + self.axis(var.id.name, loops[loop_var], loop_var, iter_type) + + super().__init__(axis_remap, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.remap", get_param_list(self.func) + + @register class BlockPredicate(SpecialStmt): """Special function where(predicate) @@ -449,7 +743,12 @@ def var(dtype, span): assert isinstance( self.node, ast.Assign ), f"VarDef expected ast.Assign but got {type(self.node)}" - v = te.var(self.node.lhs.id.name, dtype, span=span) + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) + v = te.var(names[0], dtype, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(var, def_symbol=True) @@ -464,8 +763,13 @@ def buffer_var(dtype, storage_scope, span): assert isinstance( self.node, ast.Assign ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) - v = te.var(self.node.lhs.id.name, ptr_type, span=span) + v = te.var(names[0], ptr_type, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(buffer_var, def_symbol=True) @@ -480,7 +784,12 @@ def env_thread(env_name, span): assert isinstance( self.node, ast.Assign ), f"EnvThread expected ast.Assign but got {type(self.node)}" - v = te.var(self.node.lhs.id.name, span=span) + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) + v = te.var(names[0], span=span) self.context.func_var_env_dict[v] = env_name self.context.update_symbol(v.name, v, self.node) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 4e5826f5b2a2..9af09296e9cc 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -31,6 +31,11 @@ class TargetKind(Object): """Kind of a compilation target""" + @property + def options(self): + """Returns the dict of available option names and types""" + return dict(_ffi_api.ListTargetKindOptions(self)) + @tvm._ffi.register_object class Target(Object): diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 681e322b2082..cb0305d49e4a 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -467,10 +467,12 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i, j, k in T.grip(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] Returns ------- diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index 0413c44208b0..2cb228c357e5 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -74,6 +74,7 @@ def pytest_collection_modifyitems(config, items): # pylint: disable=unused-argument _count_num_fixture_uses(items) _remove_global_fixture_definitions(items) + _sort_tests(items) @pytest.fixture @@ -236,6 +237,25 @@ def _remove_global_fixture_definitions(items): delattr(module, name) +def _sort_tests(items): + """Sort tests by file/function. + + By default, pytest will sort tests to maximize the re-use of + fixtures. However, this assumes that all fixtures have an equal + cost to generate, and no caches outside of those managed by + pytest. A tvm.testing.parameter is effectively free, while + reference data for testing may be quite large. Since most of the + TVM fixtures are specific to a python function, sort the test + ordering by python function, so that + tvm.testing.utils._fixture_cache can be cleared sooner rather than + later. + + Should be called from pytest_collection_modifyitems. + + """ + items.sort(key=lambda item: item.location) + + def _target_to_requirement(target): if isinstance(target, str): target = tvm.target.Target(target) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 44006239acfd..428403a98f16 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -25,7 +25,7 @@ from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle -from .expr import Call, CallEffectKind, Let, IterVar, Any +from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 2bfa0aacb184..27cf5351a077 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -442,7 +442,7 @@ def __init__(self, dom, var, iter_type, thread_tag="", span=None): @tvm._ffi.register_object("tir.CommReducer") class CommReducer(Object): - """Communicative reduce operator + """Commutative reduce operator Parameters ---------- diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 6a90924912b1..b002ace0e400 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -108,8 +108,10 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: A = T.match_buffer(a, (m, n), "float32") B = T.match_buffer(b, (m, n), "float32") - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] Then we can make it specialized with given shapes or buffers. @@ -129,8 +131,10 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") - with T.block([16, 16], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] Returns ------- diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 09a52d2e7037..786982cf704c 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -397,7 +397,8 @@ def before_fuse(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do fuse: @@ -419,9 +420,9 @@ def after_fuse(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # the 2 loops are fused into 1 for i_j_fused in T.serial(0, 16384): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, tir.floordiv(i_j_fused, 128)) - T.bind(vj, T.floormod(i_j_fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_fused, 128)) + vj = T.axis.S(128, T.floormod(i_j_fused, 128)) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -468,7 +469,8 @@ def before_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B") as [vi, vj]: + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do split: @@ -490,9 +492,9 @@ def after_split(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # the original loop is split into 2 loops for i0, i1, j in T.grid(2, 64, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, ((i0*64) + i1)) - T.bind(vj, j) + with T.block("B"): + vi = T.axis.S(128, i0 * 64 + i1) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -529,7 +531,8 @@ def before_reorder(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do reorder: @@ -551,9 +554,8 @@ def after_reorder(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # Here j and i are reordered for j, i in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -586,9 +588,8 @@ def before_parallel(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do parallel: @@ -609,9 +610,8 @@ def after_parallel(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.parallel(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -642,9 +642,8 @@ def before_vectorize(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do vectorize: @@ -665,9 +664,8 @@ def after_vectorize(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.serial(0, 128): for j in T.vectorized(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -706,9 +704,8 @@ def before_bind(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do bind: @@ -730,9 +727,8 @@ def after_bind(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.thread_binding(0, 128, thread = "blockIdx.x"): for j in T.thread_binding(0, 128, thread = "threadIdx.x"): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -758,9 +754,8 @@ def before_unroll(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do unroll: @@ -781,9 +776,8 @@ def after_unroll(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.unroll(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -825,7 +819,8 @@ def before_cache_read(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and cache_read: @@ -847,10 +842,12 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) A_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block([128, 128], "A_local") as [vi, vj]: + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) A_local[vi, vj] = A[vi, vj] for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_local[vi, vj] * 2.0 """ @@ -893,7 +890,8 @@ def before_cache_write(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and cache_write: @@ -915,10 +913,12 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) B_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block([128, 128], "A_local") as [vi, vj]: + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) B_local[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_local[vi, vj] """ @@ -974,10 +974,14 @@ def before_compute_at(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do compute-at: @@ -1000,14 +1004,12 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 """ @@ -1061,10 +1063,14 @@ def before_reverse_compute_at(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do reverse-compute-at: @@ -1087,14 +1093,12 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 """ @@ -1135,10 +1139,14 @@ def before_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do compute-inline: @@ -1156,8 +1164,10 @@ def before_inline(a: T.handle, c: T.handle) -> None: def after_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member @@ -1195,10 +1205,14 @@ def before_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do reverse-compute-inline: @@ -1216,8 +1230,10 @@ def before_inline(a: T.handle, c: T.handle) -> None: def after_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member @@ -1384,8 +1400,9 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: def before_rfactor(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128), - T.reduce_axis(0, 128)], "B") as [vii, vi, vj]: + for ii, i, j in T.grid(128, 128, 128): + with T.block("B"): + vii, vi, vj = T.axis.remap("SRR", [ii, i, j]) with T.init(): B[vii] = 0.0 B[vii] = B[vii] + A[vii, vi, vj] @@ -1408,14 +1425,18 @@ def after_rfactor(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128]) B_rf = T.alloc_buffer([128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]: - with T.init(): - B_rf[vi2, vii] = 0.0 - B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]: - with T.init(): - B[vii_1] = 0.0 - B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1]) + for i2, ii, i in T.grid(128, 128, 128): + with T.block("B_rf"): + vi2, vii, vi = T.axis.remap("SSR", [i2, ii, i]) + with T.init(): + B_rf[vi2, vii] = 0.0 + B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) + for ii, i2 in T.grid(128, 128): + with T.block("B"): + vii, vi2 = T.axis.remap("SR", [ii, i2]) + with T.init(): + B[vii] = 0.0 + B[vii] = B[vii] + B_rf[vi2, vii] Note @@ -1483,10 +1504,14 @@ def before_storage_align(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do storage_align: @@ -1505,11 +1530,15 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 After lowering passes, buffer B will have strides as [129, 1]. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f072f6b38a43..722810e9aa5b 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -628,7 +628,7 @@ def CompactBufferAllocation(): .. code-block:: python for i in range(0, 16): - with T.block([]): + with T.block(): B = T.alloc_buffer(16, 16) for j in range(0, 16): B[i, j] = A[i, j] + 1 @@ -643,7 +643,7 @@ def CompactBufferAllocation(): .. code-block:: python for i in range(0, 16): - with T.block([]): + with T.block(): B = T.alloc_buffer(1, 16) for j in range(0, 16): B[0, j] = A[i, j] + 1 @@ -715,3 +715,14 @@ def MergeDynamicSharedMemoryAllocations(): The result pass """ return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore + + +def ConvertForLoopsToSerial(): + """Convert Parallel For Loops to Serial For Loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ConvertForLoopsToSerial() # type: ignore diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 6b22cf13f5b9..e243d6ee3bc7 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -45,6 +45,7 @@ from .scan import * from .einsum import * from .unique import * +from .searchsorted import * from . import generic from . import nn from . import x86 diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index b3af36740551..0500eb55996c 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -505,15 +505,15 @@ def _callback(op): return s -@autotvm.register_topi_compute("conv2d_direct_simd.arm_cpu") -def conv2d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype): - """Compute conv2d with SIMD (v7e-m).""" - return direct_simd.conv2d_direct_simd_compute( +@autotvm.register_topi_compute("conv2d_nhwc_direct_simd.arm_cpu") +def conv2d_nhwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d_nhwc with SIMD (v7e-m).""" + return direct_simd.conv2d_nhwc_direct_simd_compute( cfg, data, kernel, strides, padding, dilation, out_dtype ) -@autotvm.register_topi_schedule("conv2d_direct_simd.arm_cpu") -def schedule_conv2d_direct_simd(cfg, outs): - """Create schedule for conv2d_direct_simd""" - return direct_simd.conv2d_direct_simd_nhwc_schedule(cfg, outs) +@autotvm.register_topi_schedule("conv2d_nhwc_direct_simd.arm_cpu") +def schedule_conv2d_nhwc_direct_simd(cfg, outs): + """Create schedule for conv2d_nhwc_direct_simd""" + return direct_simd.conv2d_nhwc_direct_simd_schedule(cfg, outs) diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index c7c572c81110..cbe8644c885f 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -90,7 +90,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py b/python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py index 307312076a7e..5ef9fd813eb2 100644 --- a/python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py +++ b/python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py @@ -30,7 +30,7 @@ ) -def conv2d_direct_simd(*args, **kwargs): +def conv2d_nhwc_direct_simd(*args, **kwargs): """Defines the Cortex-M7 SIMD implementation of conv2d.""" assert not kwargs, "Do not support kwargs in template function call" args = deserialize_args(args) @@ -39,17 +39,17 @@ def conv2d_direct_simd(*args, **kwargs): cfg = autotvm.get_config() args = [cfg] + args assert layout == "NHWC" - conv = conv2d_direct_simd_compute(*args) - sched = conv2d_direct_simd_nhwc_schedule(cfg, [data, kernel, conv]) + conv = conv2d_nhwc_direct_simd_compute(*args) + sched = conv2d_nhwc_direct_simd_schedule(cfg, [data, kernel, conv]) return sched, [data, kernel, conv] -conv2d_direct_simd.template_key = "direct_simd" -conv2d_direct_simd.default_data_layout = "NHWC" -conv2d_direct_simd.default_kernel_layout = "HWOI" +conv2d_nhwc_direct_simd.template_key = "direct_simd" +conv2d_nhwc_direct_simd.default_data_layout = "NHWC" +conv2d_nhwc_direct_simd.default_kernel_layout = "HWOI" -def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): +def conv2d_nhwc_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute function for Cortex-M7 SIMD implementation of conv2d.""" assert isinstance(strides, int) or len(strides) == 2 assert isinstance(dilation, int) or len(dilation) == 2 @@ -146,7 +146,7 @@ def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, ou return conv -def conv2d_direct_simd_nhwc_schedule(cfg, outs): +def conv2d_nhwc_direct_simd_schedule(cfg, outs): """Schedule function for Cortex-M7 SIMD implementation of conv2d.""" sched = te.create_schedule([x.op for x in outs]) diff --git a/python/tvm/topi/bifrost/conv2d.py b/python/tvm/topi/bifrost/conv2d.py index 3b6cca6aaea4..633f36c0e7ff 100644 --- a/python/tvm/topi/bifrost/conv2d.py +++ b/python/tvm/topi/bifrost/conv2d.py @@ -477,7 +477,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 21ddf57ca1d0..88d306761310 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -59,3 +59,4 @@ from .sparse_reshape import * from .transform import * from .unique import * +from .searchsorted import * diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 4863a06b728d..3d05058ff52c 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -46,7 +46,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/conv3d_alter_op.py b/python/tvm/topi/cuda/conv3d_alter_op.py index faf73e77255a..c7ec7cb21fcf 100644 --- a/python/tvm/topi/cuda/conv3d_alter_op.py +++ b/python/tvm/topi/cuda/conv3d_alter_op.py @@ -35,7 +35,7 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py new file mode 100644 index 000000000000..1c39ccaa8632 --- /dev/null +++ b/python/tvm/topi/cuda/searchsorted.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""searchsorted operator for GPU""" +import tvm +from tvm import te +from .. import utils +from ..searchsorted import binary_search + + +def searchsorted(sorted_sequence, values, right, out_dtype="int64"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ + + def ir(sorted_sequence, values, indices): + ib = tvm.tir.ir_builder.create() + sorted_sequence_shape = sorted_sequence.shape + values_shape = values.shape + num_search = utils.prod(values_shape) + search_range = sorted_sequence_shape[-1] + + sorted_sequence = ib.buffer_ptr(sorted_sequence) + values = ib.buffer_ptr(values) + indices = ib.buffer_ptr(indices) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr( + bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads) + ) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < num_search): + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = tid // values_shape[-1] + sequence_offset = sequence_id * search_range + + indices[tid] = binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence, + values[tid], + right, + out_dtype, + ) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted", + dtype=out_dtype, + ) diff --git a/python/tvm/topi/intel_graphics/conv2d_alter_op.py b/python/tvm/topi/intel_graphics/conv2d_alter_op.py index 0b59a849c2c9..199d984af1e4 100644 --- a/python/tvm/topi/intel_graphics/conv2d_alter_op.py +++ b/python/tvm/topi/intel_graphics/conv2d_alter_op.py @@ -35,7 +35,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): cfg = dispatch_ctx.query(target, None) workload = cfg.workload else: - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/mali/conv2d.py b/python/tvm/topi/mali/conv2d.py index f3ef55b9a30c..051914113a5b 100644 --- a/python/tvm/topi/mali/conv2d.py +++ b/python/tvm/topi/mali/conv2d.py @@ -531,7 +531,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py new file mode 100644 index 000000000000..28ffd170c955 --- /dev/null +++ b/python/tvm/topi/searchsorted.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""searchsorted operator""" +from . import utils +from . import te +from ..tir import ir_builder +from .math import cast + + +def binary_search(ib, sequence_offset, search_range, sorted_sequence, value, right, out_dtype): + """Common IR generator for binary search used by CPU and GPU backends. + + `sorted_sequence` is a N-D Buffer whose innermost dimension we want to search for `value`, + and `search_range` is the size of the innermost dimension. `sequence_offset` is + a 1-D linearlized offset specifying which of innermost sequences to search. + + So the search for `value` is performed over + `sorted_sequence[sequence_offset:(sequence_offset + search_range)]`. + Note that we index N-D Buffer by 1-D linearlized indices. + + """ + lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") + hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") + + lo[0] = cast(0, out_dtype) + hi[0] = cast(search_range, out_dtype) + + # Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu + def condition(current_val, target_val): + if right: + return current_val <= target_val + return current_val < target_val + + with ib.while_loop(lo[0] < hi[0]): + mid = lo[0] + (hi[0] - lo[0] >> 1) + with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], value)): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + return lo[0] + + +def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ + + def ir(sorted_sequence, values, indices): + ib = ir_builder.create() + sorted_sequence_shape = sorted_sequence.shape + values_shape = values.shape + num_search = utils.prod(values_shape) + search_range = sorted_sequence_shape[-1] + + sorted_sequence = ib.buffer_ptr(sorted_sequence) + values = ib.buffer_ptr(values) + indices = ib.buffer_ptr(indices) + + with ib.for_range(0, num_search, name="i", kind="parallel") as i: + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = i // values_shape[-1] + sequence_offset = sequence_id * search_range + + indices[i] = binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence, + values[i], + right, + out_dtype, + ) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted", + dtype=out_dtype, + ) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index d10c49f5c084..2d7d0a4b9e11 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -73,3 +73,4 @@ from .batch_to_space_nd import batch_to_space_nd_python from .nll_loss import nll_loss from .dense import dense +from .searchsorted import searchsorted_ref diff --git a/python/tvm/topi/testing/searchsorted.py b/python/tvm/topi/testing/searchsorted.py new file mode 100644 index 000000000000..10762600992d --- /dev/null +++ b/python/tvm/topi/testing/searchsorted.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The reference implementation of searchsorted in Numpy.""" +import numpy as np + + +def searchsorted_ref(sorted_sequence, values, right, out_dtype): + """Run Numpy searchsorted on 1-D or N-D sorted_sequence.""" + side = "right" if right else "left" + if len(sorted_sequence.shape) == 1 and len(values.shape) > 1: + sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1)) + else: + sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1])) + + values_2d = np.reshape(values, (-1, values.shape[-1])) + indices = np.zeros(values_2d.shape, dtype=out_dtype) + + for i in range(indices.shape[0]): + indices[i] = np.searchsorted(sorted_sequence_2d[i], values_2d[i], side=side) + + return np.reshape(indices, values.shape) diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py index 8e47dff37ce6..3f2df655a615 100644 --- a/python/tvm/topi/x86/conv2d_alter_op.py +++ b/python/tvm/topi/x86/conv2d_alter_op.py @@ -57,7 +57,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): cfg = dispatch_ctx.query(target, None) workload = cfg.workload else: - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 8db84497f82d..1d64261a50d7 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -35,7 +35,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): M, K = get_const_tuple(data_tensor.shape) N, _ = get_const_tuple(weight_tensor.shape) - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.dense"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 94db659e25c9..29153037b9fa 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -535,6 +535,40 @@ void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) { this->AddToSelf(other->base * scale); } +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + auto factor_str = [](int64_t f) { + return f == SplitExprNode::kPosInf ? std::string("+inf") : std::to_string(f); + }; + p->stream << "split("; + p->Print(op->index); + p->stream << ", lower=" << factor_str(op->lower_factor) + << ", upper=" << factor_str(op->upper_factor) << ", scale=" << op->scale + << ", div_mode="; + switch (op->div_mode) { + // No "default", so that the compiler will emit a warning if more div modes are + // added that are not covered by the switch. + case kTruncDiv: + p->stream << "truncdiv"; + break; + case kFloorDiv: + p->stream << "floordiv"; + break; + } + p->stream << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sum(base=" << op->base; + for (const SplitExpr& s : op->args) { + p->stream << ", "; + p->Print(s); + } + }); + // Sub-class RewriteSimplifier::Impl to take benefit of // rewriter for condition simplification etc. class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index be78bc4aa9f9..aaf7d48b10c5 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -44,13 +45,6 @@ #include "search_policy/utils.h" #include "utils.h" -namespace tvm { -// import the function from driver_api.cc -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list); -} // namespace tvm - namespace tvm { namespace auto_scheduler { @@ -1268,35 +1262,25 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i Array tensors; std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); + + // When inlining, replace const matrices with const values. + // Produces wrong IR, but good enough for feature extraction, and + // can improve the speed of feature extraction/search. Must be + // called before ScheduleToModule to have an effect. sch = sch.normalize_for_feature_extraction(); - auto bounds = te::InferBound(sch); try { - auto stmt = te::ScheduleOps(sch, bounds, false); - Map out_binds; - Array out_arg_list; - bool compact = te::VerifyCompactBuffer(stmt); const std::string& name = "main"; - GlobalVar global_var(name); - - // Copied from driver_api.cc::lower auto pass_ctx = tvm::transform::PassContext::Current(); - GetBinds(tensors, compact, std::unordered_map(), &out_binds, - &out_arg_list); - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + auto mod = ScheduleToModule(sch, Array{tensors.begin(), tensors.end()}, name, + std::unordered_map()); + bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - auto mod = IRModule(Map({{global_var, f}})); - if (IsGPUTask(task)) { auto pass_list = Array(); // Phase 0 @@ -1323,9 +1307,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i const auto& optimize = tir::transform::Sequential(Array{tir::transform::Simplify()}); mod = optimize(std::move(mod)); - const auto& it = mod->functions.find(global_var); - ICHECK(it != mod->functions.end()); - const auto& prim_func = (*it).second.as(); + PrimFunc prim_func = Downcast(mod->Lookup(name)); GetPerStoreFeature(prim_func->body, task->hardware_params->cache_line_bytes, max_n_bufs, feature); } catch (Error& e) { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bfea3e7b67c0..24cae798988e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -42,17 +42,27 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +using tvm::Array; +using tvm::transform::Pass; bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); return pf != nullptr; } +bool ShouldAnnotateEntryFunc(const Target target, const IRModule mod) { + const bool aot_executor = (target->GetAttr("executor").value_or("") == "aot"); + const bool single_entry_func = (mod->functions.size() == 1); + return single_entry_func && !aot_executor; +} + /*! \return The default host target for a given device target */ Target DefaultTargetHost(Target target) { if (target.defined() && target->kind->device_type == kDLCPU) { @@ -155,6 +165,13 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); } +static transform::Pass AnnotateEntryFunc(bool b) { + auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {}); +} + template transform::Pass Filter(FCond fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { @@ -184,7 +201,7 @@ Array CreatePassList(bool disable_loop_partition) { Array user_lower_phase2 = Array(); Array user_lower_phase3 = Array(); - // phase pasees is of the form + // phase passes is of the form // [[phase_number, pass], [phase_number, pass]... ] for (Array phase_pass : add_lower_pass) { const IntImmNode* phase_num = phase_pass[0].as(); @@ -266,24 +283,29 @@ IRModule LowerWithPassList(IRModule mod, Array pass_list) return mod; } +IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { + mod = seq(std::move(mod)); + return mod; +} + +// Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { - // Convert te schedule to IRModule - Array out_arg_list; - transform::PassContext pass_ctx = transform::PassContext::Current(); - sch = sch.normalize(); + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool debug_keep_trivial_loop = + pass_ctx->GetConfig("tir.debug_keep_trivial_loop", Bool(false)).value(); + // Before TIR transformation. - Map bounds = te::InferBound(sch); - tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false); + tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop); bool compact = te::VerifyCompactBuffer(stmt); Map out_binds; + Array out_arg_list; GetBinds(args, compact, binds, &out_binds, &out_arg_list); - // Build the function - // At this point binds is only te::Tensors + // Build the function, converting from te::Tensor to tir::Buffer tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); @@ -304,7 +326,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") const Map& binds) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; - if (binds.get() != nullptr) { + if (binds.defined()) { for (auto kv : binds) { c_binds.insert({kv.first, kv.second}); } @@ -373,88 +395,97 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); }); -std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg, - const transform::PassContext& pass_ctx) { +/** + * This function takes the input module that contains both the device and host opts. + * Then, it applies transformation on the original module before splitting into separate modules for + * device and host. Then it also applies transformations on the new splitted modules. + */ +std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, + const Target& target_host_arg) { Target target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); - Array mixed_pass_list = {BindTarget(target), - tir::transform::VerifyMemory()}; - mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - if (pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value()) { - mixed_pass_list.push_back(tir::transform::ThreadSync("global")); - } - mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); - mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); - mixed_pass_list.push_back(tir::transform::InferFragment()); - mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); + ICHECK(mod_mixed.defined()) << "This module must be defined"; - if (target->GetAttr("unpacked-api").value_or(Bool(false))) { - mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); - } else { - mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); - } + mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); - mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); - auto opt_mixed = transform::Sequential(mixed_pass_list); - mod_mixed = opt_mixed(std::move(mod_mixed)); - - auto host_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != - CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target_host), - tir::transform::LowerTVMBuiltin(), - tir::transform::LowerCustomDatatypes(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - tir::transform::CombineContextCall(), - }; - auto opt_host = transform::Sequential(host_pass_list); - ICHECK(mod_mixed.defined()) << "This module must be defined"; - auto mhost = opt_host(mod_mixed); - - // device pipeline - auto device_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == - CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target), - tir::transform::LowerWarpMemory(), - tir::transform::Simplify(), - tir::transform::LowerCustomDatatypes(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - }; - auto opt_device = transform::Sequential(device_pass_list); - auto mdevice = opt_device(mod_mixed); + IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); - // some final misc checks. auto keys = target->GetKeys(); + + CheckAndUpdateHostConsistency(&target, &target_host); + bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && mdevice->functions.size() == 0) { - LOG(WARNING) << "Specified target " << target->str() - << " but cannot find device code. Did you forget to bind?"; + if (target_is_gpu && device_mod->functions.size() == 0) { + DLOG(WARNING) << "Specified target " << target->str() + << " but cannot find device code. Did you forget to bind?"; } - if (target->kind->device_type == kDLCPU && target_host == target) { - // TODO(@jroesch): This check is no longer true we need to figure out if we care about this. - // We need to relax this check for just TIR functions. - // ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " - // << "and host_target are both llvm target." - // << "\n"; + return {host_mod, device_mod}; +} + +runtime::Module PreProcessModuleForBuild(const Map& inputs_arg, + const Target& host_target) { + std::vector device_modules; + Map inputs = inputs_arg; + Target target_host = host_target; + + CheckAndUpdateHostConsistency(&inputs, &target_host); + + if (!target_host.defined()) { + for (const auto& it : inputs) { + if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) { + target_host = it.first; + break; + } + } + } + + if (!target_host.defined()) { + target_host = DefaultTargetHost(target_host); } - return {mhost, mdevice}; + // Update target host for all targets + CheckAndUpdateHostConsistency(&inputs, &target_host); + + IRModule mhost_all = IRModule(Map()); + + ICHECK(mhost_all.defined()) << "The host module must be defined"; + + for (const auto& it : inputs) { + if (it.second.defined()) { + auto pair = SplitMixedModule(it.second, it.first, target_host); + auto& host_mod = pair.first; + auto& device_mod = pair.second; + + ICHECK(host_mod.defined()) << "The split host module must be defined"; + + ICHECK(mhost_all.defined()) << "The host module must be defined"; + + mhost_all->Update(host_mod); + + if (device_mod->functions.size() != 0) { + device_modules.push_back(codegen::Build(device_mod, it.first)); + } + } + } + + runtime::Module complete_mod = codegen::Build(mhost_all, target_host); + for (const auto& it : device_modules) { + if (it.operator->()) { + complete_mod.Import(it); + } + } + return complete_mod; } -// Can we make this take one annotated IRModule? -// -// Build for heterogeneous execution. +TVM_REGISTER_GLOBAL("driver.preprocess_module") + .set_body_typed([](const Map& inputs_arg, Target host_target) { + return PreProcessModuleForBuild(inputs_arg, host_target); + }); + runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); @@ -487,29 +518,41 @@ runtime::Module build(const Map& inputs_arg, const Target& tar for (const auto& it : inputs) { if (it.second.defined()) { - auto pair = SplitDevHostFuncs(it.second, it.first, target_host, pass_ctx); - auto& mhost = pair.first; - auto& mdevice = pair.second; + const Target& target = it.first; + const IRModule& ir_module = it.second; + auto pair = SplitMixedModule(ir_module, target, target_host); + auto& host_mod = pair.first; + auto& device_mod = pair.second; - ICHECK(mhost.defined()) << "The split host module must be defined"; + ICHECK(host_mod.defined()) << "The split host module must be defined"; ICHECK(mhost_all.defined()) << "The host module must be defined"; - mhost_all->Update(mhost); + // We don't want library modules going back into host codegen + // unless they're supposed to. Here if we overrode the target host + // to allow lowering previously we check that it's meant to be placed + // back into the host Module. + bool overrides_host_target = target->kind->device_type == target_host->kind->device_type; + bool non_host_target_kind = target->kind != target_host->kind; + if (overrides_host_target && non_host_target_kind) { + device_modules.push_back(codegen::Build(host_mod, it.first)); + } else { + mhost_all->Update(host_mod); + } - if (mdevice->functions.size() != 0) { - device_modules.push_back(codegen::Build(mdevice, it.first)); + if (device_mod->functions.size() != 0) { + device_modules.push_back(codegen::Build(device_mod, it.first)); } } } runtime::Module mhost = codegen::Build(mhost_all, target_host); - // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { mhost.Import(it); } } + return mhost; } @@ -534,8 +577,97 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, const Target& target_host_arg) { auto target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); + // More maps of target and target host Map inputs = {{target, funcs}}; return build(inputs, target_host); } +transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + + Array mixed_pass_list; + + mixed_pass_list.push_back(BindTarget(target)); + + mixed_pass_list.push_back(tir::transform::VerifyMemory()); + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); + + if (ShouldAnnotateEntryFunc(target, mixed_mod)) { + mixed_pass_list.push_back(AnnotateEntryFunc(true)); + } + + bool detect_global_barrier = + pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); + if (detect_global_barrier) { + mixed_pass_list.push_back(tir::transform::ThreadSync("global")); + } + + mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); + mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); + mixed_pass_list.push_back(tir::transform::InferFragment()); + mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); + + if (target->GetAttr("unpacked-api").value_or(Bool(false))) { + mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); + } else { + mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); + } + mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + + return transform::Sequential(mixed_pass_list); +} + +TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target) { + return MixedModulePassManager(mixed_mod, target); + }); + +transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { + Array host_pass_list; + host_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + CallingConv::kDeviceKernelLaunch; + })); + + ICHECK(mixed_mod.defined()) << "This module must be defined"; + + host_pass_list.push_back(BindTarget(target_host)); + + host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); + host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + host_pass_list.push_back(tir::transform::LowerIntrin()); + host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + host_pass_list.push_back(tir::transform::CombineContextCall()); + + return transform::Sequential(host_pass_list); +} + +TVM_REGISTER_GLOBAL("driver.host_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target_host) { + return HostModulePassManager(mixed_mod, target_host); + }); + +transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { + Array device_pass_list; + device_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDeviceKernelLaunch; + })); + + device_pass_list.push_back(BindTarget(target)); + + device_pass_list.push_back(tir::transform::LowerWarpMemory()); + device_pass_list.push_back(tir::transform::Simplify()); + device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + device_pass_list.push_back(tir::transform::LowerIntrin()); + + return transform::Sequential(device_pass_list); +} + +TVM_REGISTER_GLOBAL("driver.device_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target_host) { + return DeviceModulePassManager(mixed_mod, target_host); + }); + } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 15c441d61a23..3deb70dd766c 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -170,7 +170,7 @@ Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) } LOG(FATAL) << adt << " does not contain constructor " << cons; - throw std::runtime_error("Constructor Not Found."); + return {}; } tvm::Array IRModuleNode::GetGlobalTypeVars() const { diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 3514f3228e27..316d59631782 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -276,6 +276,8 @@ class TIRTextPrinter : public StmtFunctor, std::unordered_map memo_var_; /*! \brief Map from Buffer to Doc */ std::unordered_map memo_buf_; + /*! \brief Map from Buffer to Doc */ + std::unordered_map memo_producer_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; @@ -321,7 +323,9 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const AssertStmtNode* op) override; Doc VisitStmt_(const StoreNode* op) override; Doc VisitStmt_(const BufferStoreNode* op) override; + Doc VisitStmt_(const ProducerStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; + Doc VisitStmt_(const ProducerRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; @@ -342,7 +346,9 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintIterVar(const IterVarNode* op); Doc PrintRange(const RangeNode* op); Doc PrintBuffer(const BufferNode* op); + Doc PrintProducer(const DataProducerNode* op); Doc BufferNode2Doc(const BufferNode* op, Doc doc); + Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc); Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } Doc PrintBufferRegion(const BufferRegionNode* op); @@ -361,6 +367,7 @@ class TIRTextPrinter : public StmtFunctor, Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); + Doc AllocProducer(const DataProducer& buffer); /*! * \brief special method to render vectors of docs with a separator * \param vec vector of docs @@ -372,6 +379,9 @@ class TIRTextPrinter : public StmtFunctor, String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate); + } // namespace tir } // namespace tvm diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fa132f079793..302c4491cebe 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -65,6 +65,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) { return PrintRange(node.as()); } else if (node->IsInstance()) { return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintProducer(node.as()); } else if (node->IsInstance()) { return PrintString(node.as()); } else if (node->IsInstance()) { @@ -199,6 +201,19 @@ Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { } } +Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) { + const DataProducer& prod = GetRef(op); + + if (meta_->InMeta(prod)) { + return meta_->GetMetaNode(prod); + } else if (memo_producer_.count(prod)) { + return memo_producer_[prod]; + } else { + memo_producer_[prod] = AllocProducer(prod); + return DataProducerNode2Doc(op, memo_producer_[prod]); + } +} + Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " << Print(buf->strides); @@ -220,6 +235,11 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { return doc << ")"; } +Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) { + return doc << Doc::Text(": DataProducer(") << Print(prod->GetNameHint()) << ", " + << PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")"; +} + Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; doc << Print(op->buffer) << "["; @@ -439,6 +459,12 @@ Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) { + Doc doc; + doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc doc; doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", " @@ -446,6 +472,13 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) { + Doc doc; + doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", " + << Print(op->condition) << ", " << PrintBody(op->body) << ")"; + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto scope = GetPtrStorageScope(op->buffer_var); @@ -709,6 +742,20 @@ Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) { return val; } +Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) { + const auto& it = memo_producer_.find(producer); + if (it != memo_producer_.end()) { + return it->second; + } + std::string name = producer->GetNameHint(); + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "tensor_" + name; + } + Doc val = GetUniqueName(name); + memo_producer_[producer] = val; + return val; +} + Doc TIRTextPrinter::PrintSep(const std::vector& vec, const Doc& sep) { Doc seq; if (vec.size() != 0) { diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index fa74e56f491c..d82ad74fd5c3 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -22,10 +22,10 @@ * \brief Printer class to print Tensor IR to python syntax script */ -#include #include #include #include +#include #include #include #include @@ -91,7 +91,7 @@ class TVMScriptPrinter : public StmtFunctor, */ TVM_DLL Doc Print(const ObjectRef& node); - private: + protected: /*! \brief The tir prefix */ String tir_prefix_; /*! \brief whether show meta data */ @@ -119,8 +119,6 @@ class TVMScriptPrinter : public StmtFunctor, std::unordered_map memo_buf_; /*! \brief Map from Buffer to Declaration Doc */ std::unordered_map memo_buf_decl_; - /*! \brief Map from CommReducer to Doc */ - std::unordered_map memo_reducer_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief number of children of current node's parent */ @@ -128,7 +126,17 @@ class TVMScriptPrinter : public StmtFunctor, /*! \brief the number of current node */ int current_num_; /*! \brief loop stack without annotations */ - std::vector loop_stack_; + std::vector simple_loop_stack_; + /*! \brief the maps from loop_vars to the loops */ + std::unordered_map loop_var_map_; + /*! + * \brief simple block vars remap from loop vars + * simple_remap requires: + * 1. block var iter type is kDataPar or kCommReduce + * 2. value is a single Var, which is a loop_var outside the block + * 3. The iter range is equal to loop range + */ + std::vector> block_var_remaps_; Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override; @@ -193,11 +201,15 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); Doc AllocBufferDeclaration(const Buffer& buf); - Doc PrintBlockVar(const BlockNode* op); + Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); + Doc PrintBlockVarRemaps(); + Doc PrintBlockVars(const BlockRealizeNode* op); Doc PrintBlockAttr(const BlockRealizeNode* op); Doc PrintBlockBody(const BlockNode* op); + virtual Doc PrintBlockName(const BlockNode* block_op); Doc PrintBufferRegion(const BufferRegionNode* op); Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); + Doc PrintCommReducer(const CommReducerNode* op); Doc PrintAnnotations(const Map& annotations); static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } @@ -205,15 +217,24 @@ class TVMScriptPrinter : public StmtFunctor, Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); void TryDeallocVar(const Var& var); + bool ContainsOptionalInfo(const Stmt& stmt); /*! Helper functions for loop printing. */ /*! * \brief Print a single for loop * \param loop The for loop to be printed */ - Doc PrintLoop(const For& loop); + virtual Doc PrintLoop(const For& loop); /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ Doc PrintLoopStack(); + /*! + * \brief Print all simple loops in stack into one line using tir_prefix_.grid(). + * \param for_op the for node to be checked + */ + bool IsSimpleLoop(const ForNode* for_op) { + return for_op->kind == ForKind::kSerial && for_op->annotations.empty() && + is_zero(for_op->min) && !ContainsOptionalInfo(GetRef(for_op)); + } /*! * \brief Print additional info about expr in comment. @@ -222,11 +243,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintOptionalInfo(const Stmt& stmt) { Doc doc; // default annotations - if (annotate_ != nullptr) { + if (ContainsOptionalInfo(stmt)) { std::string annotated_stmt = annotate_(stmt); - if (!annotated_stmt.empty()) { - doc << "# " << annotated_stmt << Doc::NewLine(); - } + doc << "# " << annotated_stmt << Doc::NewLine(); } return doc; } @@ -379,6 +398,16 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { return val; } +/*! + * \brief Check if any optional information exists in annotate_ for + * a given Stmt. + * \param stmt The statement. + */ +bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) { + if (annotate_ == nullptr) return false; + return !annotate_(stmt).empty(); +} + /*! * \brief Try to dealloc vars out of space and leave the index to coming vars. * \note It is not a necessary step. @@ -415,6 +444,39 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { return doc; } +Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) { + Doc doc; + int n_var = static_cast(op->rhs.size()); + + doc << tir_prefix_ << ".comm_reducer(lambda "; + for (const Var& v_lhs : op->lhs) { + doc << Print(v_lhs) << ", "; + } + for (int i = 0; i < n_var; ++i) { + doc << Print(op->rhs[i]) << (i == n_var - 1 ? ": " : ", "); + } + if (n_var == 1) { + doc << Print(op->result[0]) << ", "; + } else { + doc << "("; + for (int i = 0; i < n_var; ++i) { + doc << Print(op->result[i]); + if (i != n_var - 1) { + doc << ", "; + } + } + doc << "), "; + } + doc << Print(op->identity_element) << ")"; + + // Remove the vars in `lhs` and `rhs`, because they are the parameters of the printed lambda. + for (int i = 0; i < n_var; ++i) { + memo_var_.erase(op->lhs[i]); + memo_var_.erase(op->rhs[i]); + } + return doc; +} + Doc TVMScriptPrinter::Print(const ObjectRef& node) { if (!node.defined()) return Doc::Text("None"); if (node->IsInstance()) { @@ -442,6 +504,8 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) { return PrintBufferRegion(node.as()); } else if (node->IsInstance()) { return PrintMatchBufferRegion(node.as()); + } else if (node->IsInstance()) { + return PrintCommReducer(node.as()); } else { LOG(FATAL) << "Do not know how to print " << node->GetTypeKey(); return Doc(); @@ -821,21 +885,23 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { Doc doc; var_not_in_headers_.insert(op->loop_var.get()); + loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); - bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); - if (simple_loop) loop_stack_.push_back(GetRef(op)); + bool simple_loop = IsSimpleLoop(op); + if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out - if (simple_loop && body != nullptr) { - Doc result = Print(GetRef(body)); + if (simple_loop && body != nullptr && IsSimpleLoop(body)) { + doc << Print(GetRef(body)); TryDeallocVar(op->loop_var); - return result; + loop_var_map_.erase(op->loop_var.get()); + return doc; } // It is a loop that can not be compressed - bool print_above = !loop_stack_.empty(); + bool print_above = !simple_loop_stack_.empty(); // print loops above if needed if (print_above) { doc << PrintLoopStack(); - loop_stack_.clear(); + simple_loop_stack_.clear(); } if (!simple_loop) { // print current loop if needed @@ -847,6 +913,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } TryDeallocVar(op->loop_var); + loop_var_map_.erase(op->loop_var.get()); return doc; } @@ -901,52 +968,100 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } -Doc TVMScriptPrinter::PrintBlockVar(const BlockNode* op) { +/*! Helper functions for block printing. */ +Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { + Doc doc; + doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis."; + switch (iter_var->iter_type) { + case kDataPar: + doc << "spatial"; + break; + case kCommReduce: + doc << "reduce"; + break; + case kOrdered: + doc << "scan"; + break; + case kOpaque: + doc << "opaque"; + break; + default: + LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type; + break; + } + doc << "("; + const Range& dom = iter_var->dom; + if (is_zero(dom->min)) { + doc << Print(dom->extent); + } else { + doc << "(" << Print(dom->min) << ", " << Print(dom->min + dom->extent) << ")"; + } + doc << ", " << Print(value) << ")"; + return doc; +} + +Doc TVMScriptPrinter::PrintBlockVarRemaps() { + ICHECK(!block_var_remaps_.empty()); + if (block_var_remaps_.size() == 1) { + const IterVar& iter_var = block_var_remaps_[0].first; + const PrimExpr& value = block_var_remaps_[0].second; + return PrintBlockVar(iter_var, value); + } Doc doc; - doc << "with " << tir_prefix_ << ".block(["; - std::vector block_var_docs; - for (const auto& iter_var : op->iter_vars) { - Doc block_var_doc; - if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) { - block_var_doc << Print(iter_var->dom->extent); + std::vector iter_vars, iter_values; + std::string iter_type; + for (const auto& pair : block_var_remaps_) { + const IterVar& iter_var = pair.first; + const PrimExpr& value = pair.second; + iter_vars.push_back(Print(iter_var->var)); + iter_values.push_back(Print(value)); + if (iter_var->iter_type == kDataPar) { + iter_type += "S"; + } else if (iter_var->iter_type == kCommReduce) { + iter_type += "R"; } else { - block_var_doc << tir_prefix_ << "."; - switch (iter_var->iter_type) { - case kDataPar: - block_var_doc << "range"; - break; - case kCommReduce: - block_var_doc << "reduce_axis"; - break; - case kOrdered: - block_var_doc << "scan_axis"; - break; - case kOpaque: - block_var_doc << "opaque_axis"; - break; - default: - LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type; - break; - } - block_var_doc << "(" << Print(iter_var->dom->min) << ", " - << Print(iter_var->dom->min + iter_var->dom->extent) << ")"; + ICHECK(false); } - block_var_docs.push_back(block_var_doc); } - doc << PrintSep(block_var_docs, Doc::Text(", ")) << "]"; - if (!op->name_hint.empty()) { - doc << ", " << Doc::StrLiteral(op->name_hint); - } - doc << ")"; - std::vector block_var_names; - for (const auto& iter_var : op->iter_vars) { + doc << PrintSep(iter_vars, Doc::Text(", ")) << " = " << tir_prefix_ << ".axis.remap(" + << Doc::StrLiteral(iter_type) << ", [" << PrintSep(iter_values, Doc::Text(", ")) << "])"; + return doc; +} + +Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) { + Doc doc; + const auto* block_op = op->block.as(); + ICHECK_EQ(block_op->iter_vars.size(), op->iter_values.size()); + tir::ExprDeepEqual expr_equal; + + auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var, + const PrimExpr& value) -> bool { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false; + if (!value->IsInstance()) return false; + const Var& var = Downcast(value); + auto it = loop_var_map_.find(var.get()); + return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) && + expr_equal(it->second->extent, iter_var->dom->extent); + }; + + for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { + const IterVar& iter_var = block_op->iter_vars[i]; + const PrimExpr& value = op->iter_values[i]; var_not_in_headers_.insert(iter_var->var.get()); - block_var_names.push_back(Print(iter_var->var)); + if (is_simple_remap(iter_var, value)) { + block_var_remaps_.push_back(std::make_pair(iter_var, value)); + } else { + if (!block_var_remaps_.empty()) { + doc << Doc::NewLine() << PrintBlockVarRemaps(); + block_var_remaps_.clear(); + } + doc << Doc::NewLine() << PrintBlockVar(iter_var, value); + } } - if (!block_var_names.empty()) { - doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]"; + if (!block_var_remaps_.empty()) { + doc << Doc::NewLine() << PrintBlockVarRemaps(); + block_var_remaps_.clear(); } - doc << ":"; return doc; } @@ -957,10 +1072,6 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) { if (!is_one(op->predicate)) { block_attr_doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")"; } - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) - block_attr_doc << Doc::NewLine() << tir_prefix_ << ".bind(" - << Print(block_op->iter_vars[i]->var) << ", " << Print(op->iter_values[i]) - << ")"; block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" << Print(block_op->reads) << ")"; block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" << Print(block_op->writes) << ")"; if (!block_op->annotations.empty()) { @@ -991,15 +1102,31 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { return body; } +/*! + * \brief Print the name of a block + * \param block_op The block node to be printed + */ +Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { + Doc doc; + doc << "with " << tir_prefix_ << ".block("; + if (!block_op->name_hint.empty()) { + doc << Doc::StrLiteral(block_op->name_hint); + } + doc << "):"; + return doc; +} + Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { const auto* block_op = op->block.as(); + Doc doc = PrintOptionalInfo(GetRef(block_op)); // print block name and block vars - Doc doc = PrintBlockVar(block_op); + doc << PrintBlockName(block_op); + Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); // print body Doc body = PrintBlockBody(block_op); - doc << Doc::Indent(4, block_attr_doc << Doc::NewLine() << body); + doc << Doc::Indent(4, block_var << block_attr_doc << Doc::NewLine() << body); for (const auto& iter_var : block_op->iter_vars) { TryDeallocVar(iter_var->var); } @@ -1060,7 +1187,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { memo_var_.clear(); memo_buf_.clear(); memo_buf_decl_.clear(); - memo_reducer_.clear(); var_not_in_headers_.clear(); buf_not_in_headers_.clear(); // print signature @@ -1085,15 +1211,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { body << Print((*it).first) << ", " << memo_buf_decl_[(*it).second]; body << ")" << Doc::NewLine(); } - // print comm_reducer - for (const auto& it : memo_reducer_) { - body << it.second << " = .comm_reducer("; - var_not_in_headers_.insert(it.first->lhs[0].get()); - var_not_in_headers_.insert(it.first->rhs[0].get()); - body << "lambda " << Print(it.first->lhs[0]) << ", " << Print(it.first->rhs[0]) << ": " - << Print(it.first->result[0]) << ", " << Print(it.first->identity_element[0]); - body << ")" << Doc::NewLine(); - } // print body body << "# body" << Doc::NewLine(); if (op->body->IsInstance() && @@ -1265,11 +1382,11 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) { Doc TVMScriptPrinter::PrintLoopStack() { Doc res; - if (loop_stack_.size() == 1) { - res << PrintLoop(loop_stack_[0]); - } else if (loop_stack_.size() > 1) { + if (simple_loop_stack_.size() == 1) { + res << PrintLoop(simple_loop_stack_[0]); + } else if (simple_loop_stack_.size() > 1) { std::vector vars, extents; - for (const auto& loop : loop_stack_) { + for (const auto& loop : simple_loop_stack_) { vars.push_back(Print(loop->loop_var)); extents.push_back(Print(loop->extent)); } @@ -1279,6 +1396,45 @@ Doc TVMScriptPrinter::PrintLoopStack() { return res; } +/*! + * \brief The printer for TVMScript with diagnostic + * \details The printer obtain the precedence of the top-level operation when printing each + * subexpression to decide whether or not parentheses is needed. + */ +class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { + public: + explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) + : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} + + protected: + Doc PrintBlockName(const BlockNode* block_op) override; + Doc PrintUnderline(const Stmt& stmt, int length); + Doc PrintLoop(const For& loop) override; +}; + +Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { + Doc doc = TVMScriptPrinter::PrintBlockName(block_op); + doc << PrintUnderline(GetRef(block_op), doc.str().size()); + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) { + Doc doc; + // annotation + if (ContainsOptionalInfo(stmt)) { + String underline = std::string(length, '^'); + doc << Doc::NewLine() << underline; + } + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { + Doc res = TVMScriptPrinter::PrintLoop(loop); + res << PrintUnderline(loop, res.str().size()); + return res; +} + String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { ICHECK(mod->IsInstance() || mod->IsInstance()); return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n"; @@ -1286,5 +1442,13 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_met TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) { + ICHECK(mod->IsInstance() || mod->IsInstance()); + return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n"; +} + +TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic); + } // namespace tir } // namespace tvm diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 22e2e9a71040..1421906a3bbb 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -131,12 +131,11 @@ class TypeSolver::Unifier : public TypeFunctor { Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type); if (!resolved.defined()) { - solver_->diag_ctx_.Emit( - Diagnostic::Error(this->span) - << "The Relay type checker is unable to show the following types match.\n" - << "In particular " - << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" - << PrettyPrint(rhs->resolved_type) << "`"); + solver_->Emit(Diagnostic::Error(this->span) + << "The Relay type checker is unable to show the following types match.\n" + << "In particular " + << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" + << PrettyPrint(rhs->resolved_type) << "`"); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -233,11 +232,10 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span) - << "tensor type `" << PrettyPrint(tt1) << "` has " - << tt1->shape.size() << " dimensions, while `" - << PrettyPrint(tt2) << "` has " << tt2->shape.size() - << " dimensions"); + this->solver_->Emit(Diagnostic::Error(this->span) + << "tensor type `" << PrettyPrint(tt1) << "` has " << tt1->shape.size() + << " dimensions, while `" << PrettyPrint(tt2) << "` has " + << tt2->shape.size() << " dimensions"); return Type(nullptr); } @@ -266,7 +264,7 @@ class TypeSolver::Unifier : public TypeFunctor { err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) << " does not match " << std::get<2>(mismatch) << "."; } - this->solver_->diag_ctx_.Emit(err); + this->solver_->Emit(err); return Type(nullptr); } @@ -526,7 +524,7 @@ class TypeSolver::Merger : public TypeFunctor { // constructor TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx) : reporter_(make_object(this)), - current_func(current_func), + current_func_(current_func), diag_ctx_(diag_ctx), module_(diag_ctx->module) { ICHECK(module_.defined()); @@ -618,7 +616,7 @@ bool TypeSolver::Solve() { rnode->resolved = resolved; } catch (const CompileError& err) { - this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what()); + this->Emit(Diagnostic::Error(rnode->span) << err.what()); rnode->resolved = false; } catch (const Error& e) { ICHECK(false) << e.what(); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 56cea60ceeda..3bde1a1e3746 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -94,7 +94,7 @@ class TypeSolver { * \brief Report a diagnostic. * \param diag The diagnostic to report. */ - void EmitDiagnostic(const Diagnostic& diag); + void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); } private: class OccursChecker; @@ -176,13 +176,9 @@ class TypeSolver { /*! \brief Reporter that reports back to self */ TypeReporter reporter_; /*! \brief The global representing the current function. */ - GlobalVar current_func; - - public: + GlobalVar current_func_; /*! \brief The diagnostic context. */ DiagnosticContext diag_ctx_; - - private: /*! \brief The module. */ IRModule module_; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 38eb6aa6a07e..3c9c35c4f254 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -182,9 +182,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { * \return The corresponding token. */ StorageInfo GetStorage(const Expr& expr) { - auto props = GetOnDeviceProps(expr); // See through "on_device" calls. - Expr true_expr = props.body.defined() ? props.body : expr; + Expr true_expr = IgnoreOnDevice(expr); VisitExpr(true_expr); auto it = storage_device_map_.find(true_expr); ICHECK(it != storage_device_map_.end()); @@ -440,31 +439,33 @@ class AOTExecutorCodegen : public MixedModeVisitor { void VisitExpr_(const LetNode* op) override { // TODO(giuseros): support Let nodes in AOT - CHECK(false) << "Let not yet implemented in AOT"; + LOG(FATAL) << "Let not yet implemented in AOT"; } void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } void VisitExpr_(const OpNode* op) override { - throw std::runtime_error("can not compile op in non-eta expanded form"); + LOG(FATAL) << "All OpNodes should have been expanded"; + } + void VisitExpr_(const IfNode* op) override { + LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called"; } - void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } void VisitExpr_(const FunctionNode* op) override { ICHECK(op->GetAttr(attr::kCompiler).defined()) << "FunctionNode only supported by custom codegen"; } void VisitExpr_(const RefCreateNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)"; } void VisitExpr_(const RefReadNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "AOT executor does not support references (found RefReadNode)"; } void VisitExpr_(const RefWriteNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)"; } void VisitExpr_(const ConstructorNode* op) override { - throw std::invalid_argument("ADT constructor case not yet implemented"); + LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)"; } void VisitExpr_(const MatchNode* op) override { - throw std::invalid_argument("match case not yet implemented"); + LOG(FATAL) << "AOT executor does not support matching (found MatchNode)"; } // Create the main PrimFunc to execute the graph. Please note that diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ef82ed617508..7005e94c2411 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -33,7 +33,7 @@ #include "../../target/func_registry_generator.h" #include "../../target/source/codegen_source_base.h" -#include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { @@ -295,8 +295,6 @@ class RelayBuildModule : public runtime::ModuleNode { executor_ = executor; CheckAndUpdateHostConsistency(&targets_, &target_host_); BuildRelay(mod, params_, mod_name); - // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. - CompileEngine::Global()->Clear(); } protected: diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc deleted file mode 100644 index 0e7af2278375..000000000000 --- a/src/relay/backend/compile_engine.cc +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.cc - * \brief Internal compilation engine. - */ -#include "compile_engine.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../runtime/meta_data.h" -#include "../transforms/pass_utils.h" -#include "te_compiler_cache.h" -#include "utils.h" - -namespace tvm { -namespace relay { - -TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); - -class CompileEngineImpl : public CompileEngineNode { - public: - // Lower the function. - CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { - return LowerInternal(key, mangle_fn)->cached_func; - } - - CachedFunc Lower(const CCacheKey& key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; - - return Lower(key, mangle_fn); - } - - // For now, build one module per function. - PackedFunc JIT(const CCacheKey& key) final { - auto mangle_fn = [](String name) { return name; }; - CCacheValue value = LowerInternal(key, mangle_fn); - if (value->packed_func != nullptr) return value->packed_func; - auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); - value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); - return value->packed_func; - } - - CachedFunc LowerShapeFunc(const CCacheKey& key) final { - return LowerShapeFuncInternal(key)->cached_func; - } - - Array LowerExternalFunctions() { - Array ret; - std::unordered_map cached_symbol; - std::vector cached_ext_funcs; - for (const auto& it : cache_) { - auto src_func = it.first->source_func; - ICHECK(src_func.defined()); - - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); - ICHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen.value(); - cached_ext_funcs.push_back(it.first); - - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" - << AsText(src_func, false) << "\n" - << "Functions with external codegen must have the " - << tvm::attr::kGlobalSymbol << " attr set."; - - std::string sn = symbol_name.value(); - if (!cached_symbol.count(sn)) { - cached_symbol[sn] = code_gen_name; - } else { - ICHECK_NE(cached_symbol[sn], code_gen_name) - << "Found duplicated symbol: " << sn << " for: " << code_gen_name; - } - - std::string ext_name = "relay.ext." + code_gen_name; - auto pf = tvm::runtime::Registry::Get(ext_name); - ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; - // No need to keep compiler attribute at this point, functions have been - // extracted for specific codegen. - src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); - runtime::Module ext_mod = (*pf)(src_func); - - // todo(@zhiics, @jroesch): Should this be a user visible error? - ICHECK(ext_mod.defined()) << "No external library was generated for " << ext_name - << "even though it was requested" - "by the annotated function " - << PrettyPrint(src_func); - - ret.push_back(ext_mod); - } - } - - // No need to cache external functions as we collected them all to create - // external runtime modules. - for (const auto& it : cached_ext_funcs) { - cache_.erase(it); - } - return ret; - } - - void Clear() final { cache_.clear(); } - - // List all items in the cache. - Array ListItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - // List all items in the shape_func_cache. - Array ListShapeFuncItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : shape_func_cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - /*! - * \brief Get the cache key of the function that is being lowered currently - * \return the cache key - */ - CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } - - private: - // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = cache_.find(key); - if (it != cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - if (!backend::IsCompileEngineCacheDisabled()) { - cache_[key] = value; - } - } - cur_ccache_key_ = key; - - // No need to lower external functions for now. We will invoke the external - // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto ir_module = IRModule(); - const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - auto func_name = std::string(name_node.value()); - auto target = Target("ext_dev"); - auto global_var = GlobalVar(func_name); - global_var->checked_type_ = key->source_func->checked_type(); - ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); - return value; - } - - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(mangle_fn(name), &name_map_); - }); - - // Skip lowering for device copy node. - const Expr body = (key->source_func)->body; - if (const CallNode* call_node = body.as()) { - if (call_node->attrs.as()) { - value->cached_func = cfunc; - return value; - } - } - - // NOTE: array will copy on write. - Array all_args = Array(cfunc->inputs); - for (te::Tensor arg : cfunc->outputs) { - all_args.push_back(arg); - } - // lower the function - std::unordered_map binds; - auto func_name = cfunc->prim_fn_var->name_hint; - cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); - value->cached_func = cfunc; - - return value; - } - - // implement lowered shape func - CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = shape_func_cache_.find(key); - if (it != shape_func_cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - shape_func_cache_[key] = value; - } - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); - - auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(name, &name_map_); - }); - - value->cached_func = cached_func; - return value; - } - - /*! \brief compiler cache lock*/ - std::mutex mutex_; - /*! \brief internal name map to get an unique name */ - std::unordered_map name_map_; - /*! \brief internal compiler cache */ - std::unordered_map cache_; - /*! \brief internal compiler cache for shape funcs */ - std::unordered_map shape_func_cache_; - /*! \brief the cache key of the function that is being lowered currently*/ - CCacheKey cur_ccache_key_; -}; - -/*! \brief The global compile engine */ -CompileEngine& CompileEngine::Global() { - // intentionally allocate raw pointer to avoid - // free during destructuion. - static CompileEngine* inst = new CompileEngine(make_object()); - return *inst; -} - -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool); - -TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") - .set_body_typed([](tvm::Array outputs, OpImplementation impl) { - return LoweredOutput(outputs, impl); - }); - -TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") - .set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { - return CompileEngine::Global(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) { - self->Clear(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") - .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { - return self->Lower(key, mod_name); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") - .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListItems(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListShapeFuncItems(); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->GetCurrentCCacheKey(); - }); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h deleted file mode 100644 index 4afdc6d30485..000000000000 --- a/src/relay/backend/compile_engine.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.h - * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. - * - * This layer represents the older design of the Relay compilation flow and is being deprecated - * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of - * Relay functions. - * - */ -#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ -#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "te_compiler_cache.h" - -namespace tvm { -namespace relay { - -using namespace tvm::relay::tec; - -/*! - * \brief Backend compilation engine for - * low level code generation. - */ -class CompileEngineNode : public Object { - public: - /*! \brief destructor */ - virtual ~CompileEngineNode() {} - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The mangling function for mangling names. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; - - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The module name to mangle the functions. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, const String mangle_fn) = 0; - /*! - * \brief Just in time compile to get a PackedFunc. - * \param key The key to the cached function. - * \return The result. - */ - virtual PackedFunc JIT(const CCacheKey& key) = 0; - /*! - * \brief Lower the shape function. - * \param key The key to the cached function. - * \return The result. - */ - virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; - /*! - * \brief Lower the external function using external codegen tools. - * \return The runtime moduels for each needed external codegen tool. - */ - virtual tvm::Array LowerExternalFunctions() = 0; - - /*! \brief clear the cache. */ - virtual void Clear() = 0; - - // VisitAttrs - void VisitAttrs(AttrVisitor*) {} - - static constexpr const char* _type_key = "relay.CompileEngine"; - TVM_DECLARE_FINAL_OBJECT_INFO(CompileEngineNode, Object); -}; - -/*! \brief cache entry used in compile engine */ -class CompileEngine : public ObjectRef { - public: - CompileEngine() {} - explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} - CompileEngineNode* operator->() { return static_cast(get_mutable()); } - using ContainerType = CompileEngineNode; - /*! \brief The global compile engine. */ - TVM_DLL static CompileEngine& Global(); -}; - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index 97b308e51e18..3e675215e7e0 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -606,25 +606,37 @@ std::pair, std::vector> EthosnCompiler::GetInput return std::make_pair(input_order, output_order); } -auto ctx = transform::PassContext::Current(); -auto cfg = ctx -> GetConfig("relay.ext.ethos-n.options").defined() - ? ctx -> GetConfig("relay.ext.ethos-n.options") - : AttrsWithDefaultValues(); -auto m_Queries = sl::SupportQueries(sl::GetFwAndHwCapabilities( - sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); +std::unique_ptr EthosnCompiler::m_Queries; + +EthosnError EthosnCompiler::SupportedSetup() { + if (m_Queries == nullptr) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relay.ext.ethos-n.options").defined() + ? ctx->GetConfig("relay.ext.ethos-n.options") + : AttrsWithDefaultValues(); + m_Queries = std::make_unique(sl::GetFwAndHwCapabilities( + sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); + if (m_Queries == nullptr) { + return EthosnError("Could not initialise Ethos-N compiler isSupported"); + } + } + return EthosnError(); +} TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ConvolutionParams params; auto err = EthosnAPI::QnnConv2d(call, ¶ms); + err += EthosnCompiler::SupportedSetup(); if (params.is_depthwise) { *rv = !err && - m_Queries.IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); + EthosnCompiler::GetSupported()->IsDepthwiseConvolutionSupported( + params.bias_info, params.weights_info, params.conv_info, params.activation_info); } else { - *rv = !err && m_Queries.IsConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); + *rv = !err && + EthosnCompiler::GetSupported()->IsConvolutionSupported( + params.bias_info, params.weights_info, params.conv_info, params.activation_info); } }); @@ -633,8 +645,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") Call call = args[0]; FullyConnectedParams params; auto err = EthosnAPI::QnnFullyConnected(call, ¶ms); - *rv = !err && m_Queries.IsFullyConnectedSupported(params.bias_info, params.weights_info, - params.fc_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsFullyConnectedSupported( + params.bias_info, params.weights_info, params.fc_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") @@ -642,7 +655,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") Call call = args[0]; MaxPool2DParams params; auto err = EthosnAPI::MaxPool2D(call, ¶ms); - *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") @@ -650,7 +665,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") Call call = args[0]; AvgPool2DParams params; auto err = EthosnAPI::AvgPool2D(call, ¶ms); - *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") @@ -658,7 +675,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") Call call = args[0]; ReshapeParams params; auto err = EthosnAPI::Reshape(call, ¶ms); - *rv = !err && m_Queries.IsReshapeSupported(params.new_shape, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsReshapeSupported(params.new_shape, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") @@ -666,8 +685,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") Call call = args[0]; AdditionParams params; auto err = EthosnAPI::Addition(call, ¶ms); - *rv = !err && m_Queries.IsAdditionSupported(params.lhs_info, params.rhs_info, - params.output_quantization_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsAdditionSupported( + params.lhs_info, params.rhs_info, params.output_quantization_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") @@ -675,7 +695,8 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") Call call = args[0]; SigmoidParams params; auto err = EthosnAPI::Sigmoid(call, ¶ms); - *rv = !err && m_Queries.IsSigmoidSupported(params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsSigmoidSupported(params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") @@ -683,7 +704,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") Call call = args[0]; ConcatenateParams params; auto err = EthosnAPI::Concatenate(call, ¶ms); - *rv = !err && m_Queries.IsConcatenationSupported(params.input_infos, params.concat_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsConcatenationSupported(params.input_infos, + params.concat_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") @@ -691,7 +714,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") Call call = args[0]; SplitParams params; auto err = EthosnAPI::Split(call, ¶ms); - *rv = !err && m_Queries.IsSplitSupported(params.input_info, params.split_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsSplitSupported(params.input_info, params.split_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") @@ -699,7 +724,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") Call call = args[0]; DepthToSpaceParams params; auto err = EthosnAPI::DepthToSpace(call, ¶ms); - *rv = !err && m_Queries.IsDepthToSpaceSupported(params.input_info, params.depth_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsDepthToSpaceSupported(params.input_info, + params.depth_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") @@ -707,7 +734,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") Call call = args[0]; ReluParams params; auto err = EthosnAPI::Relu(call, ¶ms); - *rv = !err && m_Queries.IsReluSupported(params.relu_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsReluSupported(params.relu_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index 63ae7a3e4704..ca2df05e958d 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -287,6 +287,22 @@ class EthosnCompiler { */ static runtime::Module CreateRuntimeModule(const ObjectRef& ref); + /*! + * \brief Initialise the is-supported functionality of the Ethos-N support library + * with the target variant. + * \return Error object + */ + static EthosnError SupportedSetup(); + + /*! + * \brief Return the is-supported API of the Support Library + * \return A reference to the API. + */ + static std::unique_ptr& GetSupported() { + ICHECK(m_Queries != nullptr); + return m_Queries; + } + private: /*! * \brief Compile a single Relay Ethos-N function into an ordered compiled network. @@ -322,6 +338,8 @@ class EthosnCompiler { */ static std::pair, std::vector> GetInputOutputOrder( NetworkWithIDs network, const std::unique_ptr& compiled_network); + + static std::unique_ptr m_Queries; }; runtime::Module CompileEthosn(const ObjectRef& ref) { diff --git a/src/relay/backend/contrib/ethosu/to_te_graph.cc b/src/relay/backend/contrib/ethosu/to_te_graph.cc deleted file mode 100644 index 9646c39da089..000000000000 --- a/src/relay/backend/contrib/ethosu/to_te_graph.cc +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/contrib/ethosu/to_te_graph.cc - * \brief Lower a Relay function to a TE graph. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../compile_engine.h" -#include "../../utils.h" - -namespace tvm { -namespace relay { -namespace contrib { -namespace ethosu { - -/*! \brief Node container to represent a Tensor Expression graph. */ -class TEGraphNode : public Object { - public: - /* \brief The inputs to the graph */ - tvm::Array inputs; - /* \brief The outputs to the graph */ - tvm::Array outputs; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - } - - static constexpr const char* _type_key = "relay.TEGraph"; - TVM_DECLARE_FINAL_OBJECT_INFO(TEGraphNode, Object); -}; - -class TEGraph : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TEGraph, ObjectRef, TEGraphNode); -}; - -TVM_REGISTER_NODE_TYPE(TEGraphNode); - -Array GetShape(const Array& shape) { - // for now, we always use int32 shape when possible - // even if the result of shape inference becomes int64. - Array res; - for (IndexExpr val : shape) { - const int64_t* pval = tir::as_const_int(val); - if (pval != nullptr) { -#ifndef TVM_INDEX_DEFAULT_I64 - ICHECK_LE(pval[0], std::numeric_limits::max()); - ICHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); -#else - res.push_back(val); -#endif // TVM_INDEX_DEFAULT_I64 - } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); - } else { - res.push_back(val); - } - } - return res; -} - -class RelayToTE : public backend::MemoizedExprTranslator> { - public: - RelayToTE() = default; - - TEGraph Lower(const Function& prim_func) { - auto graph_node = make_object(); - for (Var param : prim_func->params) { - Array inputs; - if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - graph_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - graph_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } - } - memo_[param] = inputs; - } - graph_node->outputs = this->VisitExpr(prim_func->body); - return TEGraph(graph_node); - } - - Array VisitExpr_(const VarNode* op) final { - LOG(FATAL) << "Free variable " << op->name_hint(); - return {}; - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(op->is_scalar()); - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "compile_engine_const", topi::kBroadcast); - return {value}; - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); - ICHECK(flower_call) << "relay.backend.lower_call is not registered."; - - Array inputs; - int count_tuple = 0; - for (Expr arg : call_node->args) { - if (arg->checked_type().as()) { - ++count_tuple; - } - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - - Array outputs; - LoweredOutput lowered_out = - (*flower_call)(GetRef(call_node), inputs, tvm::Target("llvm")); - outputs = lowered_out->outputs; - - if (outputs.size() != 1) { - const auto* tuple_type = call_node->checked_type().as(); - ICHECK(tuple_type) << "Expect output to be a tuple type"; - ICHECK_EQ(tuple_type->fields.size(), outputs.size()); - } - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - const auto* tuple_type = op->tuple->type_as(); - Array tuple = VisitExpr(op->tuple); - ICHECK_EQ(tuple_type->fields.size(), tuple.size()); - ICHECK_GE(op->index, 0); - ICHECK_LT(static_cast(op->index), tuple.size()); - return {tuple[op->index]}; - } -}; - -TVM_REGISTER_GLOBAL("relay.backend.contrib.ethosu.LowerToTE") - .set_body_typed([](Function prim_func) { return RelayToTE().Lower(prim_func); }); - -} // namespace ethosu -} // namespace contrib -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 6d332803041d..cae20210ec4f 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -33,7 +33,9 @@ namespace example_target_hooks { class ConvertAddToSubtract : public MixedModeMutator { public: explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) - : ir_module_(ir_module), host_target_(host_target) {} + : ir_module_(ir_module), + host_target_(host_target), + custom_target_(Target("example_target_hook")) {} IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); @@ -81,7 +83,15 @@ class ConvertAddToSubtract : public MixedModeMutator { tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), buffer_map, DictAttrs(dict_attrs)); - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + + // Switch to TIRToRuntime hook for testing + Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); + if (tir_to_runtime) { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); + } else { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + } + ir_module_->Add(new_global_var, replacement_func); } @@ -109,6 +119,7 @@ class ConvertAddToSubtract : public MixedModeMutator { public: IRModule ir_module_; Target host_target_; + Target custom_target_; }; transform::Pass RelayToTIR() { @@ -124,8 +135,4 @@ transform::Pass RelayToTIR() { } // namespace contrib } // namespace relay -TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) - .set_attr("RelayToTIR", - relay::contrib::example_target_hooks::RelayToTIR()); - } // namespace tvm diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc new file mode 100644 index 000000000000..75b161ad4499 --- /dev/null +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -0,0 +1,39 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { + +namespace relay { +namespace contrib { +namespace example_target_hooks { +tvm::transform::Pass RelayToTIR(); +runtime::Module TIRToRuntime(IRModule mod, Target target); +} // namespace example_target_hooks +} // namespace contrib +} // namespace relay + +TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) + .set_attr("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR()) + .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime); + +} // namespace tvm diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc new file mode 100644 index 000000000000..36d801d349a7 --- /dev/null +++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../../../../target/source/codegen_c_host.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace example_target_hooks { + +using namespace tir; + +class CodeGenExampleTargetHook : public codegen::CodeGenCHost { + public: + /*! + * \brief Emit code that changes adds to multiplies for testing + */ + void VisitExpr_(const SubNode* op, std::ostream& os) final { + os << '('; + PrintExpr(op->a, os); + os << " * "; + PrintExpr(op->b, os); + os << ')'; + } +}; + +runtime::Module TIRToRuntime(IRModule mod, Target target) { + bool output_ssa = false; + bool emit_asserts = false; + CodeGenExampleTargetHook codegen; + Array function_names; + codegen.Init(output_ssa, emit_asserts, target->str()); + for (auto kv : mod->functions) { + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + function_names.push_back(global_symbol.value()); + codegen.AddFunction(prim_func); + } + std::string code = codegen.Finish(); + return codegen::CSourceModuleCreate(code, "c", function_names); +} + +} // namespace example_target_hooks +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index dbe14b63293f..debd669126c4 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -473,15 +473,15 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorindex]}; } std::vector VisitExpr_(const OpNode* op) override { - throw std::runtime_error("can not compile op in non-eta expanded form"); + LOG(FATAL) << "All OpNodes should have been expanded"; return {}; } std::vector VisitExpr_(const GlobalVarNode* op) override { - throw std::runtime_error(""); + LOG(FATAL) << "All GlobalVarNodes should be removed before graph executor's Codegen is called"; return {}; } std::vector VisitExpr_(const IfNode* op) override { - throw std::invalid_argument("if not supported"); + LOG(FATAL) << "Graph executor does not support control flow (found IfNode)"; return {}; } std::vector VisitExpr_(const FunctionNode* op) override { @@ -490,23 +490,23 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const RefCreateNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "Graph executor does not support references (found RefCreateNode)"; return {}; } std::vector VisitExpr_(const RefReadNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "Graph executor does not support references (found RefReadNode)"; return {}; } std::vector VisitExpr_(const RefWriteNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "Graph executor does not support references (found RefWriteNode)"; return {}; } std::vector VisitExpr_(const ConstructorNode* op) override { - throw std::invalid_argument("ADT constructor case not yet implemented"); + LOG(FATAL) << "Graph executor does not support ADTs (found ConstructorNode)"; return {}; } std::vector VisitExpr_(const MatchNode* op) override { - throw std::invalid_argument("match case not yet implemented"); + LOG(FATAL) << "Graph executor does not support matching (found MatchNode)"; return {}; } /*! diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 7642f3ccf703..961252a14fa7 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -146,10 +146,9 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * \return The corresponding token. */ const std::vector& GetToken(const Expr& expr) { - this->VisitExpr(expr); // See through on_device calls. - auto props = GetOnDeviceProps(expr); - Expr real_expr = props.body.defined() ? props.body : expr; + Expr real_expr = IgnoreOnDevice(expr); + this->VisitExpr(real_expr); auto it = token_map_.find(real_expr.get()); ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl << PrettyPrint(real_expr); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index ef89fd9c9c6c..a596e09907d5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -37,7 +37,7 @@ #include "../op/annotation/annotation.h" #include "../transforms/pass_utils.h" -#include "./te_compiler.h" +#include "te_compiler.h" namespace tvm { namespace relay { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 445602540dbb..a8c27a126032 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -313,6 +313,45 @@ TECompiler::TECompiler() { data_ = object; } +/*! \brief The global TE compiler */ +TECompiler& TECompiler::Global() { + static TECompiler* inst = new TECompiler(make_object()); + return *inst; +} +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() { + return TECompiler::Global(); +}); + +TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") + .set_body_typed([](Function source_func, Target target) { + return CCacheKey(source_func, target); + }); + +TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") + .set_body_typed([](tvm::Array outputs, OpImplementation impl) { + return LoweredOutput(outputs, impl); + }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerClear").set_body_typed([](TECompiler self) { + self->Clear(); +}); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerLower") + .set_body_typed([](TECompiler self, CCacheKey key, const String mod_name) { + return self->Lower(key, mod_name); + }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerJIT") + .set_body_typed([](TECompiler self, CCacheKey key) { return self->JIT(key); }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECompiler self) { + TECompilerImpl* ptr = dynamic_cast(self.operator->()); + ICHECK(ptr != nullptr); + return ptr->ListItems(); +}); + using AnalysisRemapping = std::unordered_map; std::tuple IsDeviceCopy(const Function& func) { diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 248fd40f98eb..e3b7d46457ad 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -127,6 +127,7 @@ class TECompiler : public ObjectRef { explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} TECompilerNode* operator->() { return static_cast(get_mutable()); } using ContainerType = TECompilerNode; + TVM_DLL static TECompiler& Global(); }; /*! @@ -193,7 +194,7 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \returns The pass which lowers primative functions to TIR + * \returns The pass which lowers primitive functions to TIR */ transform::Pass LowerTEPass(TargetMap targets, const String& module_name, std::function process_fn); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d0e83765928a..be5b172e6a7c 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -111,8 +111,10 @@ Array GetShape(const Array& shape) { // Construct a schedule for a given Relay primitive function and target. class ScheduleBuilder : public backend::MemoizedExprTranslator> { public: - explicit ScheduleBuilder(Target target) - : target_(target), device_copy_op_(Op::Get("device_copy")) { + explicit ScheduleBuilder(Target target, bool create_schedule = true) + : target_(target), + device_copy_op_(Op::Get("device_copy")), + create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); } @@ -132,6 +134,8 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator auto outputs = this->VisitExpr(prim_func->body); auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; truncated_name << candidate_name.substr(0, kMaxFuncNameLength); @@ -149,7 +153,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator auto prim_fn_var = GlobalVar(prim_fn_name); prim_fn_var->checked_type_ = prim_func->checked_type(); - ICHECK(anchor_op_.defined()); // Fusion over tupled results may leave identity relationships // between inputs and outputs, and those should not be scheduled. // Hence schedule only non PlaceholderOp outputs. @@ -162,7 +165,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator te::Schedule schedule; // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr) { + if (anchor_attrs_.as() == nullptr && create_schedule_) { if (use_auto_scheduler_) { const auto* fauto_schedule = runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); @@ -259,17 +262,19 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator impl = lowered_out->implementation; } - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; + if (create_schedule_) { + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + anchor_implementation_ = impl; + } } if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); @@ -334,6 +339,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; + bool create_schedule_; }; /*! @@ -390,6 +396,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> // Generate a name. auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; truncated_name << candidate_name.substr(0, kMaxFuncNameLength); @@ -667,6 +675,12 @@ std::string GetUniqueName(std::string name, std::unordered_map return name; } +TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { + return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) { + return name; + }); +}); + } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 47ba96b2c77e..7975ef873173 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -62,7 +62,6 @@ struct LoweredOutputNode : public Object { v->Visit("outputs", &outputs); v->Visit("implementation", &implementation); } - static constexpr const char* _type_key = "relay.LoweredOutput"; TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); }; diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 6d59b858927c..a647aa1a3fd2 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -327,7 +327,7 @@ inline relay::Function BindParamsByName( for (auto arg : func->params) { const auto& name = arg->name_hint(); if (name_dict.count(name)) { - repeat_var.insert(arg); + repeat_var.insert(name_dict[name]); } else { name_dict[name] = arg; } @@ -427,15 +427,6 @@ inline bool IsAutoSchedulerEnabled() { .value(); } -/*! - * \brief Return whether the compile engine cache is disabled in the pass context. - */ -inline bool IsCompileEngineCacheDisabled() { - return transform::PassContext::Current() - ->GetConfig("relay.backend.disable_compile_engine_cache", Bool(false)) - .value(); -} - /*! * \brief Get the sequence of Relay optimization passes based on backend type. * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 36cd0c7f406d..b3c1cd81274f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -594,8 +594,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { auto offset_register = last_register_; // If the shape is constant then we will emit a static tensor allocation - // instruction. - auto const_shape = args[2].as(); + // instruction. It may be wrapped by an on_device, but it will be on the host + // which is assumed by the alloc_tensor instruction anyway. + auto const_shape = AsIgnoringOnDevice(args[2]); if (const_shape) { NDArray shape = const_shape->data; @@ -619,7 +620,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { this->VisitExpr(args[0]); auto size_register = last_register_; - ICHECK(args[1].as()); + ICHECK(args[1].as()); // Always a literal. NDArray alignment_arr = args[1].as()->data; ICHECK_EQ(alignment_arr->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " @@ -824,7 +825,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { /*! * \brief Compile a pattern match expression - * It first converts the pattern match expression into a desicision tree, the condition + * It first converts the pattern match expression into a decision tree, the condition * could be object comparison or variable binding. If any of the condition fails in a clause, * the decision tree switches to check the conditions of next clause and so on. If no clause * matches the value, a fatal node is inserted. diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc new file mode 100644 index 000000000000..be5921311660 --- /dev/null +++ b/src/relay/op/algorithm/searchsorted.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file searchsorted.cc + * \brief SearchSorted operators + */ +#include +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(SearchSortedAttrs); + +bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const SearchSortedAttrs* param = attrs.as(); + ICHECK_EQ(types.size(), 3); + const auto* sorted_sequence = types[0].as(); + const auto* values = types[1].as(); + ICHECK(sorted_sequence) << "Expects TensorType in the first input"; + ICHECK(values) << "Expects TensorType in the second input"; + ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one"; + + if (sorted_sequence->shape.size() > 1) { + ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size()) + << "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is " + "multi-dimensional."; + + for (size_t i = 0; i < values->shape.size() - 1; ++i) { + if (sorted_sequence->shape[i].as() && values->shape[i].as()) { + ICHECK_EQ(sorted_sequence->shape[i].as()->value, + values->shape[i].as()->value) + << "`sorted_sequence and `values` do not have the same shape along outer axes"; + } + } + } + + reporter->Assign(types[2], TensorType(values->shape, param->dtype)); + return true; +} + +Expr MakeSearchSorted(Expr sorted_sequence, Expr values, Bool right, DataType dtype) { + auto attrs = make_object(); + static const Op& op = Op::Get("searchsorted"); + attrs->dtype = dtype; + attrs->right = right; + return Call(op, {sorted_sequence, values}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.searchsorted").set_body_typed(MakeSearchSorted); + +RELAY_REGISTER_OP("searchsorted") + .describe( + R"doc(Find indices where elements should be inserted to maintain order. +If `sorted_sequence` is N-dimensional, the innermost dimension of +`values` are searched in the corresponding dimension of `sorted_sequence`. +)doc" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("sorted_sequence", "Tensor", + "Monotonically increasing sequence on the innermost dimension.") + .add_argument("values", "Tensor", "Values to search for.") + .set_support_level(6) + .add_type_rel("SearchSorted", SearchSortedRel); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index beadf4a67ddc..8b00839cda33 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -94,12 +94,7 @@ RELAY_REGISTER_OP("on_device") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("TNonComputational", true) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_type) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("TNonComputational", true); OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { if (call_node->op == OnDeviceOp()) { diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index b6dff8813fd4..d772df9b023a 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -85,6 +85,32 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node); */ OnDeviceProps GetOnDeviceProps(const Expr& expr); +/*! + * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns + * \p expr directly. + */ +inline Expr IgnoreOnDevice(const Expr& expr) { + OnDeviceProps props = GetOnDeviceProps(expr); + return props.body.defined() ? props.body : expr; +} + +/*! + * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through + * any "on_device" annotations. + */ +template +const NodeType* AsIgnoringOnDevice(const Expr& expr) { + const auto* node = expr.as(); + if (node != nullptr) { + return node; + } + OnDeviceProps props = GetOnDeviceProps(expr); + if (!props.body.defined()) { + return nullptr; + } + return props.body.as(); +} + /*! * \brief Returns \p function annotated with "param_device_types" and "result_device_type" * attributes capturing parameter and result devices types respectively. diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc new file mode 100644 index 000000000000..fa73645d45de --- /dev/null +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/depthwise.cc + * \brief Depthwise convolution 2D operator definition for the Arm(R) Ethos(TM)-U NPU + */ +#include +#include +#include +#include +#include + +#include "../../../qnn/utils.h" +#include "../../nn/convolution.h" +#include "common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes used by the Ethos(TM)-U NPU depthwise operator */ +struct EthosuDepthwiseConv2DAttrs : public tvm::AttrsNode { + double ifm_scale; + int ifm_zero_point; + int weight_zero_point; + double ofm_scale; + int ofm_zero_point; + Array kernel_shape; + IndexExpr ofm_channels; + Array strides; + Array padding; + Array dilation; + String activation; + int clip_min; + int clip_max; + String upscale; + String ifm_layout; + String ofm_layout; + + TVM_DECLARE_ATTRS(EthosuDepthwiseConv2DAttrs, "relay.attrs.EthosuDepthwiseConv2DAttrs") { + TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(weight_zero_point) + .describe("The quantization zero point for the weight tensor."); + TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(kernel_shape) + .describe("The 2 dimensional kernel shape as (kernel_height, kernel_width).") + .set_default(NullValue >()); + TVM_ATTR_FIELD(ofm_channels) + .describe("The number of OFM channels.") + .set_default(NullValue()); + TVM_ATTR_FIELD(strides) + .describe("The 2 dimensional strides as (stride_height, stride_width).") + .set_default(Array({1, 1})); + TVM_ATTR_FIELD(padding) + .describe("The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right)") + .set_default(Array({0, 0, 0, 0})); + TVM_ATTR_FIELD(dilation) + .describe("The 2 dimensional dilation as (dilation_height, dilation_width).") + .set_default(Array({1, 1})); + TVM_ATTR_FIELD(activation) + .describe( + "Description: The activation function to use." + "'NONE' - no activation function." + "'CLIP' - clip the output between clip_min and clip_max." + "'TANH - tanh activation function." + "'SIGMOID' - sigmoid activation function." + "'LUT' - use a look-up table to perform the activation function.") + .set_default("NONE"); + TVM_ATTR_FIELD(clip_min) + .describe("The minimum clipping value if activation = CLIP.") + .set_default(0); + TVM_ATTR_FIELD(clip_max) + .describe("The maximum clipping value if activation = CLIP.") + .set_default(0); + TVM_ATTR_FIELD(upscale) + .describe( + "The 2x2 upscaling mode to apply to the Input Feature Map tensor. " + "'NONE' - no upscaling. " + "'NEAREST' - upscale using nearest neighbour. " + "'ZEROS' - upscale using zeros.") + .set_default("NONE"); + TVM_ATTR_FIELD(ifm_layout) + .set_default("NHWC") + .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'."); + TVM_ATTR_FIELD(ofm_layout) + .set_default("NHWC") + .describe("The layout of the Output Feature Map tensor. Can be 'NHWC' or 'NHCWB16'."); + } +}; + +TVM_REGISTER_NODE_TYPE(EthosuDepthwiseConv2DAttrs); + +bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 5); + const auto* ifm = types[0].as(); + const auto* weight = types[1].as(); + const auto* scale_bias = types[2].as(); + if (ifm == nullptr || weight == nullptr) return false; + + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr."; + ICHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8)) + << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for ifm but was " + << ifm->dtype; + ICHECK(weight->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8)) + << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for weight but was " + << weight->dtype; + ICHECK(scale_bias->dtype == DataType::UInt(8)) + << "Expected ethosu_depthwise_conv2d type(uint8) for scale_bias but was " + << scale_bias->dtype; + + // Collect the ifm, weight and ofm tensors for using in the inference function + Array tensor_types = {types[0], types[1], types[4]}; + + // Assign weight type {ofm_channels, kernel_height, kernel_width, 1} + reporter->Assign(types[1], TensorType({param->ofm_channels, param->kernel_shape[0], + param->kernel_shape[1], weight->shape[3]}, + weight->dtype)); + + // Assign ofm type + auto ofm_shape = + EthosuInferKernelOutput(ifm->shape, param->ifm_layout, param->ofm_layout, param->kernel_shape, + param->ofm_channels, param->dilation, param->strides, param->padding); + + reporter->Assign(types[4], TensorType(ofm_shape, ifm->dtype)); + + return true; +} + +Expr MakeEthosuDepthwiseConv2D(Expr ifm, Expr weight, Expr scale_bias, Expr lut, double ifm_scale, + int ifm_zero_point, int weight_zero_point, double ofm_scale, + int ofm_zero_point, Array kernel_shape, + IndexExpr ofm_channels, Array strides, + Array padding, Array dilation, + String activation, int clip_min, int clip_max, String upscale, + String ifm_layout, String ofm_layout) { + auto attrs = make_object(); + attrs->ifm_scale = ifm_scale; + attrs->ifm_zero_point = ifm_zero_point; + attrs->weight_zero_point = weight_zero_point; + attrs->ofm_scale = ofm_scale; + attrs->ofm_zero_point = ofm_zero_point; + attrs->kernel_shape = std::move(kernel_shape); + attrs->ofm_channels = std::move(ofm_channels); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->activation = std::move(activation); + attrs->clip_min = clip_min; + attrs->clip_max = clip_max; + attrs->upscale = std::move(upscale); + attrs->ifm_layout = std::move(ifm_layout); + attrs->ofm_layout = std::move(ofm_layout); + static const Op& op = Op::Get("contrib.ethosu.depthwise_conv2d"); + return Call(op, {ifm, weight, scale_bias, lut}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.ethosu_depthwise_conv2d") + .set_body_typed(MakeEthosuDepthwiseConv2D); + +RELAY_REGISTER_OP("contrib.ethosu.depthwise_conv2d") + .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized depthwise operator. + +This Relay operator corresponds to the hardware-implemented quantized +depthwise operation found on Ethos(TM)-U NPUs. It accepts either NHWC or NHCWB16 format +for the input data (input feature map, or IFM) and OHWI format for the kernel weights. + +- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **weight**: (ofm_channels, kernel_shape[0], kernel_shape[1], 1 (depth multiplier)) +- **scale_bias**: (ofm_channels, 10) +- **ofm**: (1, ofm_height, ofm_width, ofm_channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") + .add_argument("weight", "Tensor", "The weight tensor.") + .add_argument("scale_bias", "Tensor", "The packed per-channel weight scale and bias tensor.") + .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'") + .set_support_level(11) + .add_type_rel("EthosuDepthwiseConv2D", EthosuDepthwiseConv2DRel); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 5339d48e3a2f..08e92b31965e 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -91,12 +91,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, Array assert_shape) { @@ -106,13 +101,9 @@ Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, attrs->assert_shape = assert_shape; } else { // Look through any on_device for the shape argument expression. - Expr literal_shape = shape; - auto props = GetOnDeviceProps(literal_shape); - if (props.body.defined()) { - // See through on_device calls. - literal_shape = props.body; - } - attrs->const_shape = Downcast(literal_shape); + const auto* constant_node = AsIgnoringOnDevice(shape); + ICHECK(constant_node); + attrs->const_shape = GetRef(constant_node); } static const Op& op = Op::Get("memory.alloc_tensor"); return Call(op, {storage, offset, shape}, Attrs(attrs), {}); @@ -206,12 +197,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); bool KillRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -230,12 +216,7 @@ RELAY_REGISTER_OP("memory.kill") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); static void FlattenTupleTypeAux(const Type& type, std::vector* out) { if (auto tt = type.as()) { diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 0d40caa15052..cf44b308ce02 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -49,8 +49,13 @@ InferCorrectLayoutOutput PoolInferCorrectLayout(const Attrs& attrs, ICHECK(attrs_ptr); ObjectPtr params = make_object(*attrs_ptr); - if (new_in_layouts.defined()) { - // Set the pool with the new layout. + if (params->out_layout != "") { + // when users specify the out_layout of pooling, follow user's preference + ICHECK_EQ(params->layout, params->out_layout) + << "Pooling input/output layouts mismatch: " << params->layout << " vs. " + << params->out_layout; + } else if (new_in_layouts.defined()) { + // the pooling is using an inferred layout (i.e., new_in_layouts[0]) given by relay caller ICHECK_EQ(new_in_layouts.size(), 1); params->layout = new_in_layouts[0].name(); } @@ -144,6 +149,7 @@ Array Pool2DCompute(const Attrs& attrs, const Array& inp auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(tir::BijectiveLayout(layout, kNCHW).defined()) << "max_pool2d currently only supports layouts that are convertible from NCHW"; @@ -178,9 +184,9 @@ Array Pool2DCompute(const Attrs& attrs, const Array& inp TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode) { + String out_layout, bool ceil_mode) { return MakeMaxPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, "nn.max_pool2d"); + out_layout, ceil_mode, "nn.max_pool2d"); }); RELAY_REGISTER_OP("nn.max_pool2d") @@ -216,9 +222,9 @@ RELAY_REGISTER_OP("nn.max_pool2d") TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, bool count_include_pad) { + String out_layout, bool ceil_mode, bool count_include_pad) { return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, count_include_pad, "nn.avg_pool2d"); + out_layout, ceil_mode, count_include_pad, "nn.avg_pool2d"); }); RELAY_REGISTER_OP("nn.avg_pool2d") @@ -303,9 +309,10 @@ Array GlobalPool2DCompute(const Attrs& attrs, const Array{topi::nn::global_pool(inputs[0], mode, layout.name())}; } -Expr MakeGlobalAvgPool2D(Expr data, String layout) { +Expr MakeGlobalAvgPool2D(Expr data, String layout, String out_layout) { auto attrs = make_object(); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.global_avg_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -331,9 +338,10 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") .set_attr("FTVMCompute", GlobalPool2DCompute); // GlobalMaxPool -Expr MakeGlobalMaxPool2D(Expr data, String layout) { +Expr MakeGlobalMaxPool2D(Expr data, String layout, String out_layout) { auto attrs = make_object(); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.global_max_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -423,10 +431,12 @@ Array AdaptivePool1DCompute(const Attrs& attrs, const Array output_size, String layout) { +Expr MakeAdaptiveAvgPool1D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_avg_pool1d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -456,10 +466,12 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool1d") .set_attr("FTVMCompute", AdaptivePool1DCompute); // relay.nn.adaptive_max_pool1d -Expr MakeAdaptiveMaxPool1D(Expr data, Array output_size, String layout) { +Expr MakeAdaptiveMaxPool1D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_max_pool1d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -571,10 +583,12 @@ Array AdaptivePool2DCompute(const Attrs& attrs, const Array output_size, String layout) { +Expr MakeAdaptiveAvgPool2D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_avg_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -606,10 +620,12 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") .set_attr("FTVMCompute", AdaptivePool2DCompute); // relay.nn.adaptive_max_pool2d -Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, String layout) { +Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_max_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -700,6 +716,7 @@ Array AdaptivePool3DCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; ICHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) @@ -737,10 +754,12 @@ Array AdaptivePool3DCompute(const Attrs& attrs, const Array output_size, String layout) { +Expr MakeAdaptiveMaxPool3D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_max_pool3d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -772,10 +791,12 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool3d") .set_attr("FTVMCompute", AdaptivePool3DCompute); // relay.nn.adaptive_max_pool3d -Expr MakeAdaptiveAvgPool3D(Expr data, Array output_size, String layout) { +Expr MakeAdaptiveAvgPool3D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_avg_pool3d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -866,12 +887,13 @@ Array Pool2DGradCompute(const Attrs& attrs, const Array& // MaxPool2DGrad Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, Array strides, Array padding, String layout, - bool ceil_mode) { + String out_layout, bool ceil_mode) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); attrs->ceil_mode = ceil_mode; static const Op& op = Op::Get("nn.max_pool2d_grad"); return Call(op, {out_grad, data}, Attrs(attrs), {}); @@ -913,12 +935,13 @@ RELAY_REGISTER_OP("nn.max_pool2d_grad") // AvgPool2DGrad Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, Array strides, Array padding, String layout, - bool ceil_mode, bool count_include_pad) { + String out_layout, bool ceil_mode, bool count_include_pad) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; static const Op& op = Op::Get("nn.avg_pool2d_grad"); @@ -976,6 +999,7 @@ bool Pool1DRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK(param != nullptr); Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w'))) << "Invalid layout " << layout << ". Pool1D layout must have W, which cannot be split"; @@ -1018,6 +1042,7 @@ Array Pool1DCompute(const Attrs& attrs, const Array& inp auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(tir::BijectiveLayout(layout, kNCW).defined()) << "max_pool1d currently only supports layouts that are convertible from NCW"; @@ -1046,9 +1071,9 @@ Array Pool1DCompute(const Attrs& attrs, const Array& inp TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode) { + String out_layout, bool ceil_mode) { return MakeMaxPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, "nn.max_pool1d"); + out_layout, ceil_mode, "nn.max_pool1d"); }); RELAY_REGISTER_OP("nn.max_pool1d") @@ -1082,9 +1107,9 @@ RELAY_REGISTER_OP("nn.max_pool1d") TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, bool count_include_pad) { + String out_layout, bool ceil_mode, bool count_include_pad) { return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, count_include_pad, "nn.avg_pool1d"); + out_layout, ceil_mode, count_include_pad, "nn.avg_pool1d"); }); RELAY_REGISTER_OP("nn.avg_pool1d") @@ -1134,6 +1159,7 @@ bool Pool3DRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK(param != nullptr); Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) @@ -1194,6 +1220,7 @@ Array Pool3DCompute(const Attrs& attrs, const Array& inp auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) << "max_pool3d currently only supports layouts that are convertible from NCDHW"; @@ -1231,9 +1258,9 @@ Array Pool3DCompute(const Attrs& attrs, const Array& inp TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode) { + String out_layout, bool ceil_mode) { return MakeMaxPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, "nn.max_pool3d"); + out_layout, ceil_mode, "nn.max_pool3d"); }); RELAY_REGISTER_OP("nn.max_pool3d") @@ -1270,9 +1297,9 @@ RELAY_REGISTER_OP("nn.max_pool3d") TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, bool count_include_pad) { + String out_layout, bool ceil_mode, bool count_include_pad) { return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, count_include_pad, "nn.avg_pool3d"); + out_layout, ceil_mode, count_include_pad, "nn.avg_pool3d"); }); RELAY_REGISTER_OP("nn.avg_pool3d") diff --git a/src/relay/op/nn/pooling.h b/src/relay/op/nn/pooling.h index 9b7eab25fe9a..32ae464101ab 100644 --- a/src/relay/op/nn/pooling.h +++ b/src/relay/op/nn/pooling.h @@ -35,13 +35,14 @@ namespace relay { template inline Expr MakeMaxPool(Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, String op_name) { + String out_layout, bool ceil_mode, String op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->dilation = std::move(dilation); attrs->padding = std::move(padding); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); attrs->ceil_mode = ceil_mode; static const Op& op = Op::Get(op_name); return Call(op, {data}, Attrs(attrs), {}); @@ -50,13 +51,14 @@ inline Expr MakeMaxPool(Expr data, Array pool_size, Array template inline Expr MakeAvgPool(Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, bool count_include_pad, String op_name) { + String out_layout, bool ceil_mode, bool count_include_pad, String op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->dilation = std::move(dilation); attrs->padding = std::move(padding); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; static const Op& op = Op::Get(op_name); diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index c9f14c91c7b1..5001925b7570 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -149,23 +149,41 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs, tvm::Array new_r_axes; std::string inferred_in_string = ""; std::string inferred_out_string = ""; - int axis_index = 0; - for (auto iter_var : layout->axes) { - const auto& layout_axis = LayoutAxis::Get(iter_var); + auto push_new_axis = [&](const std::string& layout_dim, int axis) { + if ((old_r_dims.count(layout_dim) && !params->exclude) || + (!old_r_dims.count(layout_dim) && params->exclude)) { + new_r_axes.push_back(tvm::Integer(axis)); + return true; + } + return false; + }; + for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) { + const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]); const std::string& layout_dim = layout_axis.name(); - // Collect only the primal axis. if (layout_axis.IsPrimal()) { - if (old_r_dims.count(layout_dim) && !params->exclude) { - new_r_axes.push_back(tvm::Integer(axis_index)); - } - if (!old_r_dims.count(layout_dim) && params->exclude) { - new_r_axes.push_back(tvm::Integer(axis_index)); - } + push_new_axis(layout_dim, axis_index); + inferred_in_string += layout_dim; if (!old_r_dims.count(layout_dim) || params->keepdims) { inferred_out_string += layout_dim; } - inferred_in_string += layout_dim; - axis_index++; + } else { + // For example, if the original layout is NCHW, the new layout is NCHW8c, and the original + // reduce axes is [1], the new reduce axes become [1, 4]. + auto primal_dim = layout_axis.ToPrimal().name(); + auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim; + inferred_in_string += packed_dim; + if (push_new_axis(primal_dim, axis_index)) { + if (params->exclude) { + // The primal axis is not reduced, so keep the input packed dim. + inferred_out_string += packed_dim; + } else { + // If the primal axis is part of reduce axes in the original layout, the inner dim + // becomes 1 after reduction. + inferred_out_string += "1" + layout_dim; + } + } else { + inferred_out_string += packed_dim; + } } } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3781107eeee1..90a0e3150573 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2599,24 +2599,19 @@ InferCorrectLayoutOutput StridedSliceInferCorrectLayout( params->strides = new_strides; layout = new_layout; } - } else { + } else if (old_layout_name.size() < + new_layout_name.size()) { // prohibit transforms such as NCHW4c -> NCHW if (params->axes) { auto axes = params->axes.value(); Array new_axes; - for (size_t i = 0; i < axes.size(); ++i) { auto old_idx = axes[i]; auto new_idx = new_layout.IndexOf(layout[old_idx]); new_axes.push_back(new_idx); const LayoutAxis& axis = layout[old_idx]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return out_default; - } - + ICHECK(axis.IsPrimal()); auto factor = new_layout.FactorOf(axis); - if (factor == -1) { new_begin.push_back(begin[i]); new_end.push_back(end[i]); @@ -2636,10 +2631,7 @@ InferCorrectLayoutOutput StridedSliceInferCorrectLayout( } else { for (size_t i = 0; i < begin.size(); i++) { const LayoutAxis& axis = layout[i]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return out_default; - } + ICHECK(axis.IsPrimal()); auto factor = new_layout.FactorOf(axis); if (factor == -1) { new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); @@ -3260,8 +3252,10 @@ bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, oshape.reserve(ndim_data); for (size_t i = 0; i < ndim_data; ++i) { if (i == static_cast(axis)) { - const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); - ICHECK_GE(*indice_shape_i, 1); + if (indices->shape[i].as()) { + const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); + ICHECK_GE(*indice_shape_i, 1); + } } else { ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i])); } diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc index be31b5482937..65a4ec01805b 100644 --- a/src/relay/op/vm/vm.cc +++ b/src/relay/op/vm/vm.cc @@ -138,12 +138,7 @@ RELAY_REGISTER_OP("vm.shape_func") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // vm.invoke_tvm_op bool InvokeTVMOpRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -188,12 +183,7 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // vm.reshape TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 5782f1f6b4d1..ecdd36ddb791 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -275,6 +275,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ Array padding({0, 0}); reduced_t2 = AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, param->dilation, padding, param->data_layout, + "", // out_layout false, // ceil_mode false); // count_include_pad } else { @@ -284,6 +285,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ Array padding({0, 0}); reduced_t2 = AvgPool2D(reduced_t2, param->kernel_size, param->strides, param->dilation, padding, param->data_layout, + "", // out_layout false, // ceil_mode false); // count_include_pad } @@ -463,6 +465,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, Multiply(reduced_c_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w)); reduced_t2 = AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, param->dilation, padding, param->data_layout, + "", // out_layout false, // ceil_mode false); // count_include_pad } else { @@ -471,6 +474,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, if (stride1 * stride2 != 1) { reduced_t2 = AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, param->dilation, padding, param->data_layout, + "", // out_layout false, // ceil_mode false); // count_include_pad } diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 7a86af8aeffa..c538dac048b3 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -34,7 +34,7 @@ #include #include -#include "../backend/compile_engine.h" +#include "../backend/te_compiler.h" #include "pattern_utils.h" namespace tvm { @@ -126,7 +126,8 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); + tec::PrimFuncFor(GetRef(func), Target::Current(), + [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 14a86bc8d080..5255a672a856 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -288,8 +288,8 @@ class DefuncMutator : public ExprMutator { return Call(c, call_args); } - - throw std::runtime_error("EncodeArg failed to cast arg into identifier node or function node"); + LOG(FATAL) << "EncodeArg failed to cast arg into identifier node or function node"; + return {}; } /*! diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index ed9409856871..fd7f0a5594c2 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -118,9 +118,8 @@ inline Expr TransformF(const std::function& func, const Expr& * is it atomic? * if so, the compute cost of the expression is bounded so it can be copy without graph mode. */ -inline bool IsAtomic(const Expr& e) { - auto props = GetOnDeviceProps(e); - Expr true_expr = props.body.defined() ? props.body : e; +inline bool IsAtomic(const Expr& expr) { + Expr true_expr = IgnoreOnDevice(expr); return true_expr.as() || true_expr.as() || true_expr.as() || true_expr.as() || true_expr.as(); // Constant is always by reference. diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 692ef3c9f557..03b8ee6937a7 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -676,9 +676,10 @@ static inline Expr Reshape(Expr data, Array newshape) { static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides, Array dilation, Array padding, - std::string layout, bool ceil_mode, bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool2d"); + std::string layout, std::string out_layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, + out_layout, ceil_mode, count_include_pad, "nn.avg_pool2d"); } static inline Expr Pad(Expr data, Array> pad_width, Expr pad_value, diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 5ca6d86b1d52..6d74e48e871e 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -666,7 +666,7 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Type checked_type = solver_->Resolve(it->second.checked_type); if (checked_type.as() != nullptr) { - this->solver_->diag_ctx_.Emit( + this->solver_->Emit( Diagnostic::Error(op->span) << "The type inference pass was unable to infer a type for this expression.\n" << "This usually occurs when an operator call is under constrained in some way," diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 5bbc536afaca..a336cf494f4b 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -499,7 +499,7 @@ class ACLRuntime : public JSONRuntimeBase { layer->outputs.push_back( MakeACLTensorFromJSONNode(node, &node.GetInputs()[6], &node.GetInputs()[7])); } else { - throw std::runtime_error("Unsupported form of add op: " + op_name); + LOG(FATAL) << "Unsupported form of add op: " + op_name; } auto f = std::make_shared(); diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 94bbae1559d9..a27fe1114af9 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -226,6 +226,55 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter { } }; +class Conv1DOpConverter : public TensorRTOpConverter { + public: + Conv1DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + auto weight_shape = params->inputs.at(1).weight_shape; + ICHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCW"); + ICHECK_EQ(params->node.GetAttr>("kernel_layout")[0], "OIW"); + auto str_strides = params->node.GetAttr>("strides"); + auto str_dilation = params->node.GetAttr>("dilation"); + auto str_padding = params->node.GetAttr>("padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + int channels = weight_shape[0]; + if (params->node.HasAttr("channels") && + !params->node.GetAttr>("channels")[0].empty()) { + channels = std::stoi(params->node.GetAttr>("channels")[0]); + } + + auto shuffle_layer = params->network->addShuffle(*input_tensor); + std::vector new_shape = {input_dims[0], input_dims[1], 1}; + shuffle_layer->setReshapeDimensions(VectorToTrtDims(new_shape)); + input_tensor = shuffle_layer->getOutput(0); + + const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], 1); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + + auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, + params->inputs.at(1).weight, bias); + ICHECK(conv_layer != nullptr); + conv_layer->setPadding(nvinfer1::DimsHW(std::stoi(str_padding[0]), 0)); + ICHECK_EQ(str_strides.size(), 1); + const auto strides = nvinfer1::DimsHW(std::stoi(str_strides[0]), 1); + conv_layer->setStride(strides); + ICHECK_EQ(str_dilation.size(), 1); + const auto dilation = nvinfer1::DimsHW(std::stoi(str_dilation[0]), 1); + conv_layer->setDilation(dilation); + conv_layer->setNbGroups(groups); + input_tensor = conv_layer->getOutput(0); + + auto conv_output_dims = TrtDimsToVector(input_tensor->getDimensions()); + std::vector back_shape = {0, 0}; + auto shuffle_back_layer = params->network->addShuffle(*input_tensor); + shuffle_back_layer->setReshapeDimensions(VectorToTrtDims(back_shape)); + params->outputs.push_back(shuffle_back_layer->getOutput(0)); + } +}; + class Conv2DOpConverter : public TensorRTOpConverter { public: Conv2DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} @@ -1198,6 +1247,7 @@ GetOpConverters() { map->emplace("nn.batch_norm", std::make_shared()); map->emplace("nn.layer_norm", std::make_shared()); map->emplace("nn.softmax", std::make_shared()); + map->emplace("nn.conv1d", std::make_shared()); map->emplace("nn.conv2d", std::make_shared()); map->emplace("nn.dense", std::make_shared()); map->emplace("nn.bias_add", std::make_shared()); diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h index aa718a303744..90897a9542b6 100644 --- a/src/runtime/crt/crt_config-template.h +++ b/src/runtime/crt/crt_config-template.h @@ -49,7 +49,10 @@ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 + +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index 34e81c7d33b1..3fea408d9760 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -77,7 +77,7 @@ int NodeEntry_Load(TVMGraphExecutorNodeEntry* entry, JSONReader* reader) { void TVMGraphExecutorNode_LoadAttrs(TVMGraphExecutorNode* node, JSONReader* reader, TVMOpParam* param) { int bitmask = 0; - char key[20], value[120]; + char key[20], value[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; memset(param, 0, sizeof(TVMOpParam)); memset(key, 0, sizeof(key)); memset(value, 0, sizeof(value)); @@ -796,13 +796,13 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl char* names = NULL; DLDevice dev = {kDLCPU, 0}; tvm_crt_error_t err = TVMPlatformMemoryAllocate( - TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count, dev, (void**)&names); + TVM_CRT_MAX_STRLEN_PARAM_NAME * executor->nodes_count, dev, (void**)&names); if (err != kTvmErrorNoError) { fprintf(stderr, "memory allocate error: %08x", err); status = -1; return status; } - memset(names, 0, TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count); + memset(names, 0, TVM_CRT_MAX_STRLEN_PARAM_NAME * executor->nodes_count); uint64_t names_count; int idx; memcpy(&names_count, bptr, sizeof(names_count)); @@ -811,11 +811,11 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl uint64_t name_length; memcpy(&name_length, bptr, sizeof(name_length)); bptr += sizeof(name_length); - if (name_length >= TVM_CRT_MAX_STRLEN_FUNCTION_NAME) { + if (name_length >= TVM_CRT_MAX_STRLEN_PARAM_NAME) { fprintf(stderr, "Error: function name longer than expected.\n"); status = -1; } - memcpy(names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, bptr, name_length); + memcpy(names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx, bptr, name_length); bptr += name_length; } @@ -831,9 +831,9 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl for (idx = 0; idx < size; idx++) { int32_t in_idx = - TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); + TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx); CHECK_GT(in_idx, 0, "Found param for non-existent input: %s\n", - names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); + names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx); uint32_t eid = TVMGraphExecutor_GetEntryId(executor, executor->input_nodes[in_idx], 0); if (!(eid < executor->data_entry_count)) { fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", eid, @@ -859,7 +859,7 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl #if TVM_CRT_DEBUG TVMNDArray* entry = &(executor->data_entry[eid]); printf("loading: param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n", - names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, + names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, ((float*)entry->dl_tensor.data)[0]); // NOLINT(*) #endif // TVM_CRT_DEBUG } @@ -1181,13 +1181,6 @@ int TVMGraphExecutor_Init(TVMGraphExecutor* executor, const char* graph_json, return status; } status = TVMGraphExecutor_SetupOpExecs(executor); - if (status != 0) { - if (status != 0) { - return status; - } - - return status; - } return status; } diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h index c67c43357363..d4429308b650 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h @@ -60,7 +60,7 @@ typedef struct TVMGraphExecutorNode { // operator type in string char op_type[16]; // name of the op - char name[120]; + char name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; // parameters TVMOpParam param; // inputs diff --git a/src/runtime/hexagon/launcher/CMakeLists.txt b/src/runtime/hexagon/launcher/CMakeLists.txt deleted file mode 100644 index d3a2f4f8161d..000000000000 --- a/src/runtime/hexagon/launcher/CMakeLists.txt +++ /dev/null @@ -1,156 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.2) -project(HexagonLauncher C CXX) - -if(NOT "${FASTRPC_LIBS}" STREQUAL "SKEL" AND - NOT "${FASTRPC_LIBS}" STREQUAL "STUB") - message(SEND_ERROR "Please set FASTRPC_LIBS to either SKEL or STUB") -endif() - -if(NOT DEFINED USE_HEXAGON_SDK) - message(SEND_ERROR "Please set USE_HEXAGON_SDK to the location of Hexagon SDK") -endif() -if (NOT DEFINED USE_HEXAGON_ARCH) - message(SEND_ERROR "Please set USE_HEXAGON_ARCH to the Hexagon architecture version") -endif() - -include(../../../../cmake/modules/HexagonSDK.cmake) - -find_hexagon_sdk_root("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}") - -include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_REMOTE_ROOT}) - -set(QAIC_EXE "${HEXAGON_QAIC_EXE}") -foreach(INCDIR IN LISTS HEXAGON_SDK_INCLUDES HEXAGON_REMOTE_ROOT) - list(APPEND QAIC_FLAGS "-I${INCDIR}") -endforeach() - -set(LAUNCHER_SRC "${CMAKE_CURRENT_SOURCE_DIR}") -set(CMAKE_SKIP_RPATH TRUE) - -# Qaic for the domain header. -# -# Don't add paths to these filenames, or otherwise cmake may spontaneously -# add -o option to the qaic invocation (with an undesirable path). -set(LAUNCHER_RPC_IDL "launcher_rpc.idl") -set(LAUNCHER_RPC_H "launcher_rpc.h") -set(LAUNCHER_RPC_SKEL_C "launcher_rpc_skel.c") -set(LAUNCHER_RPC_STUB_C "launcher_rpc_stub.c") - -add_custom_command( - OUTPUT ${LAUNCHER_RPC_SKEL_C} ${LAUNCHER_RPC_STUB_C} - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} - "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" - COMMAND ${CMAKE_COMMAND} -E rename "${LAUNCHER_RPC_H}" - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" - MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" -) - - -if("${FASTRPC_LIBS}" STREQUAL "SKEL") - # Skel libraries. - # - if (NOT DEFINED TVM_RUNTIME_HEXAGON) - message(SEND_ERROR "Please set TVM_RUNTIME_HEXAGON=/path/to/libtvm_runtime.a") - endif() - - include_directories(SYSTEM ${HEXAGON_QURT_INCLUDES}) - include_directories( - "${LAUNCHER_SRC}" - "${LAUNCHER_SRC}/../../../../include" - "${LAUNCHER_SRC}/../../../../3rdparty/dlpack/include" - "${LAUNCHER_SRC}/../../../../3rdparty/dmlc-core/include" - ) - link_directories(${HEXAGON_QURT_LIBS}) - - add_definitions(-D_MACH_I32=int) - add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0) - add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) - - # Extra compile flags (both C and C++). - set(EXTRA_COMP_FLAGS - "-O3" - "-m${USE_HEXAGON_ARCH}" - ) - string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") - set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") - set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") - - set(EXTRA_LINK_FLAGS - "-lposix" - "-lqurt" - "-Wl,--export-dynamic" - "-Wl,--whole-archive ${TVM_RUNTIME_HEXAGON} -Wl,--no-whole-archive" - "-Wl,--defsym=HEAP_SIZE=0x40000000" - ) - string(REGEX REPLACE ";" " " EXTRA_LINK_FLAGS_STR "${EXTRA_LINK_FLAGS}") - - set(SKEL_SRCS - "launcher_core.cc" - "launcher_hexagon.cc" - ) - add_library(launcher_rpc_skel SHARED - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" - "${LAUNCHER_RPC_SKEL_C}" - "${SKEL_SRCS}" - ) - - # Extra linker flags for linking shared libraries. - set_target_properties(launcher_rpc_skel PROPERTIES - LINK_FLAGS ${EXTRA_LINK_FLAGS_STR} - ) -else() - # Stub libraries. - # - if (NOT DEFINED TVM_RUNTIME_ANDROID) - message(SEND_ERROR "Please set TVM_RUNTIME_ANDROID=/path/to/libtvm_runtime.so") - endif() - - include_directories(SYSTEM - "${HEXAGON_SDK_INCLUDES}" - "${HEXAGON_RPCMEM_ROOT}/inc" - ) - include_directories( - "${LAUNCHER_SRC}" - "${LAUNCHER_SRC}/../../../../include" - "${LAUNCHER_SRC}/../../../../3rdparty/dlpack/include" - "${LAUNCHER_SRC}/../../../../3rdparty/dmlc-core/include" - ) - link_directories(${HEXAGON_REMOTE_ROOT}) - - add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) - - set(STUB_SRCS - "launcher_android.cc" - "launcher_core.cc" - "launcher_main.cc" - "launcher_util.cc" - ) - - add_executable(launcher_android - "${STUB_SRCS}" - "${LAUNCHER_RPC_STUB_C}" - ) - target_link_libraries(launcher_android cdsprpc log) - - set_target_properties(launcher_android PROPERTIES - LINK_FLAGS "${TVM_RUNTIME_ANDROID}" - ) -endif() diff --git a/src/runtime/micro/crt_config.h b/src/runtime/micro/crt_config.h index c3e8fea1ba08..602060de1b4a 100644 --- a/src/runtime/micro/crt_config.h +++ b/src/runtime/micro/crt_config.h @@ -37,7 +37,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 3cd5df613f4a..4e24434642d8 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -41,7 +41,7 @@ namespace runtime { struct TypeInfo { /*! \brief The current index. */ uint32_t index{0}; - /*! \brief Index of the parent in the type hierachy */ + /*! \brief Index of the parent in the type hierarchy */ uint32_t parent_index{0}; // NOTE: the indices in [index, index + num_reserved_slots) are // reserved for the child-class of this type. @@ -58,7 +58,7 @@ struct TypeInfo { }; /*! - * \brief Type context that manages the type hierachy information. + * \brief Type context that manages the type hierarchy information. */ class TypeContext { public: diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index 41f867057282..3820ce942af0 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -21,31 +21,129 @@ * \file pipeline_executor.cc */ #include "pipeline_executor.h" - namespace tvm { namespace runtime { +/*! + * \brief Give frontends an access to packed functions. + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding packed function. + */ +PackedFunc PipelineExecutor::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "get_num_outputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc(); + } + return nullptr; +} -void PipelineRuntime::Init(const Array& modules, - const std::string& pipeline_json) { - return; +/*! + * \brief Use the mod_config information to create a graph runtime list. + * \param mod_config The config information that generates by the export library function call. + */ +std::vector PipelineExecutor::CreateGraphModules(const ModuleConfig& mod_config) { + const PackedFunc* graph_executor_create = Registry::Get("tvm.graph_executor.create"); + std::vector ret; + ret.resize(mod_config.size()); + for (auto config : mod_config) { + // Load library. + auto lib = Module::LoadFromFile(config.second.lib_name.c_str()); + + // Read json. + std::ifstream ifJson(config.second.json_name.c_str()); + if (ifJson.fail()) { + LOG(FATAL) << "json file not found: " << config.second.json_name; + } + const std::string json((std::istreambuf_iterator(ifJson)), + std::istreambuf_iterator()); + + // Create a graph executor. + std::istringstream istr(config.second.dev); + std::string str; + int device_type = 1, device_id = 0; + while (getline(istr, str, ';')) { + std::istringstream istr_dev(str); + std::string str_temp; + if (getline(istr_dev, str_temp)) { + device_type = stoi(str_temp); + } + if (getline(istr_dev, str_temp)) { + device_id = stoi(str_temp); + } + } + Module graph_module = (*graph_executor_create)(json, lib, device_type, device_id); + + // Load parameters. + TVMByteArray params_arr; + const char* params_file_name = config.second.params_name.c_str(); + std::ifstream if_param(params_file_name); + if (if_param.fail()) { + LOG(FATAL) << "params file not found: " << params_file_name; + } + const std::string params((std::istreambuf_iterator(if_param)), + std::istreambuf_iterator()); + params_arr.data = params.c_str(); + params_arr.size = params.length(); + auto load_params = graph_module.GetFunction("load_params"); + load_params(params_arr); + + // Put a graph executor module into the vector. + ret[config.first] = graph_module; + } + return ret; } -/* GetFunction can not be pure abstract function, implement an empty function for now. +/*! + * \brief Initialize the pipeline executor with a list of modules to be pipelined + * and config in JSON format. + * \param modules The module list used for building the pipeline. + * \param pipeline_json The configuration of modules dependencies. */ -PackedFunc PipelineRuntime::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { - return nullptr; +void PipelineExecutor::Init(const std::vector& modules, const std::string& pipeline_json) { + ICHECK(!modules.empty()) << "The graph executor module list is empty."; + // Use JSONReader to load pipeline configuration. + std::istringstream is(pipeline_json); + dmlc::JSONReader reader(&is); + PipelineConfig& pipeline_config = this->LoadPipelineConfig(&reader); + ICHECK(!pipeline_config.Empty()) << "The pipeline config information is empty."; + // Initialize the pipeline function class used for pipeline thread pool management + // and schedule etc. This function returns the number of output. + num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config); + return; } -Module PipelineRuntimeCreate(const Array& m, - const std::string& pipeline_json) { - auto exec = make_object(); - exec->Init(m, pipeline_json); +Module PipelineExecutorCreate(const Array& m, const std::string& pipeline_json) { + ICHECK(!m.empty()) << "The module list is empty."; + auto exec = make_object(); + std::vector graph_modules; + for (auto mod : m) { + graph_modules.push_back(mod); + } + exec->Init(graph_modules, pipeline_json); + return Module(exec); +} + +Module PipelineExecutorLoad(const std::string& load_json, const std::string& pipeline_json) { + auto exec = make_object(); + std::istringstream is(load_json); + dmlc::JSONReader reader(&is); + ModuleConfig& mod_config = exec->LoadModuleConfig(&reader); + ICHECK(!mod_config.empty()) << "The module config is empty."; + std::vector modules = exec->CreateGraphModules(mod_config); + exec->Init(modules, pipeline_json); return Module(exec); } TVM_REGISTER_GLOBAL("tvm.pipeline_executor.create").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = PipelineRuntimeCreate(args[0], args[1]); + *rv = PipelineExecutorCreate(args[0], args[1]); +}); + +TVM_REGISTER_GLOBAL("tvm.pipeline_executor.load").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = PipelineExecutorLoad(args[0], args[1]); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index c7625c62b724..a883ba25ec08 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -23,9 +23,16 @@ */ #ifndef TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_ #define TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_ + #include +#include +#include +#include #include +#include + +#include "pipeline_scheduler.h" namespace tvm { namespace runtime { /*! @@ -36,18 +43,23 @@ namespace runtime { * * This executor can be accessed by various language via TVM runtime PackedFunc API. */ -class TVM_DLL PipelineRuntime : public ModuleNode { +class TVM_DLL PipelineExecutor : public ModuleNode { public: /*! * \Return the type key of the executor. */ - const char* type_key() const final { return "PipelineRuntime"; } + const char* type_key() const final { return "PipelineExecutor"; } /*! - * \brief Initialize the pipeline executor with module array and json text. + * \brief Initialize the pipeline executor with module array and JSON text. * \param modules The module list used for building pipeline. * \param pipeline_json The configuration of modules dependencies. */ - void Init(const Array& modules, const std::string& pipeline_json); + void Init(const std::vector& modules, const std::string& pipeline_json); + /*! + * \brief Use the information of mod_config to create a list of graph executor. + * \param mod_config The configuration information generated by the library export function call. + */ + std::vector CreateGraphModules(const ModuleConfig& mod_config); /*! * \brief Give frontends an access to packed functions. * \param name The name of the function. @@ -55,6 +67,86 @@ class TVM_DLL PipelineRuntime : public ModuleNode { * \return The corresponding packed function. */ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + + /*! + * \brief Get the number of outputs. + * + * \return The number of outputs. + */ + int NumOutputs() const { return num_outputs_; } + + /*!\brief Load the module files information.*/ + ModuleConfig& LoadModuleConfig(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + int mod_idx = -1; + std::string lib_name; + std::string json_name; + std::string params_name; + std::string dev; + while (reader->NextObjectItem(&key)) { + if (key == "mod_idx") { + reader->Read(&mod_idx); + } else if (key == "lib_name") { + reader->Read(&lib_name); + } else if (key == "json_name") { + reader->Read(&json_name); + } else if (key == "params_name") { + reader->Read(¶ms_name); + } else if (key == "dev") { + reader->Read(&dev); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx; + // Load the lib, json, and params information. + ICHECK(!lib_name.empty()) << "lib_name is empty."; + ICHECK(!json_name.empty()) << "json_name is empty."; + ICHECK(!params_name.empty()) << "params_name is empty."; + mod_config_[mod_idx] = GraphModuleLoadInfo(lib_name, json_name, params_name, dev); + } + return mod_config_; + } + + private: + /*!\brief The class used to execute and schedule the pipeline logic.*/ + PipelineScheduler pipeline_scheduler_; + /*!\brief The dependency information of each graph runtime module of the pipeline.*/ + PipelineConfig pipeline_config_; + /*!\brief The module information used to create the graph runtimes.*/ + ModuleConfig mod_config_; + /*!\brief How many outputs are in this pipeline executor.*/ + size_t num_outputs_ = 0; + /*!\brief Json loader.*/ + PipelineConfig& LoadPipelineConfig(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + int mod_idx = -1; + OutputMap output; + std::string dev; + while (reader->NextObjectItem(&key)) { + if (key == "mod_idx") { + reader->Read(&mod_idx); + } else if (key == "dev") { + reader->Read(&dev); + } else if (key == "output") { + reader->Read(&output); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx; + // Check if the output is successfully read. + ICHECK(!output.Empty()) << "Invalid output binding result."; + pipeline_config_.Insert(mod_idx, output); + } + return pipeline_config_; + } }; } // namespace runtime } // namespace tvm diff --git a/src/runtime/pipeline/pipeline_scheduler.cc b/src/runtime/pipeline/pipeline_scheduler.cc new file mode 100644 index 000000000000..82caf855a479 --- /dev/null +++ b/src/runtime/pipeline/pipeline_scheduler.cc @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "pipeline_scheduler.h" + +#include +#include +namespace tvm { +namespace runtime { +/*! + * \brief Initialize the pipeline. + * \param modules The list of graph executor modules. + * \param pipeline_conf The dependency information of each graph executor module. + */ +size_t PipelineScheduler::PipelineInit(const std::vector& modules, + const PipelineConfig& pipeline_config) { + graph_modules_ = modules; + int num_output = pipeline_config.GetGlobalOutputNum(); + return num_output; +} +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/pipeline/pipeline_scheduler.h b/src/runtime/pipeline/pipeline_scheduler.h new file mode 100644 index 000000000000..5ee127edffa3 --- /dev/null +++ b/src/runtime/pipeline/pipeline_scheduler.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_PIPELINE_PIPELINE_SCHEDULER_H_ +#define TVM_RUNTIME_PIPELINE_PIPELINE_SCHEDULER_H_ +#include +#include +#include + +#include +#include +#include +#include + +#include "pipeline_struct.h" +namespace tvm { +namespace runtime { +/*! + * \brief The class that executes the pipeline logic,it is used to initialize the thread pool, + execute and schedule pipeline tasks, allocate and manage memory, etc. + */ +class PipelineScheduler { + public: + /*! + * \brief Initialize the pipeline. + * \param modules The list of graph executor module. + * \param pipeline_config The dependency information of each graph executor module. + */ + size_t PipelineInit(const std::vector& modules, const PipelineConfig& pipeline_config); + + private: + /*!\brief The list of graph executors.*/ + std::vector graph_modules_; +}; +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_PIPELINE_PIPELINE_SCHEDULER_H_ diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h new file mode 100644 index 000000000000..3cc9621702c1 --- /dev/null +++ b/src/runtime/pipeline/pipeline_struct.h @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_PIPELINE_PIPELINE_STRUCT_H_ +#define TVM_RUNTIME_PIPELINE_PIPELINE_STRUCT_H_ +#include +#include +#include + +#include +#include +#include +#include +/*! + * \brief All binding information of a output interface. + */ +struct OutputBindings { + /*!\brief Output interface binding information, 'int' is the index of the module that + * uses this output data as the input interface data, 'string' is the input interface name + * of the module. + */ + std::unordered_map bindings; + /*! The index value of the global interface to which the current output are bound.*/ + int global_output_index = std::numeric_limits::min(); + /*!\brief Whether this binding is bound to the PipelineExecutor output interface.*/ + bool IsGlobalOutput() const { return global_output_index >= 0; } + /*! + * \brief Create a module interface map from JSONReader. + * \param reader JSON reader. + */ + void Load(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + std::string input_name; + int mod_idx = std::numeric_limits::min(); + // Whether the output binding is global. + bool global_binding = false; + while (reader->NextObjectItem(&key)) { + if (key == "mod_idx") { + reader->Read(&mod_idx); + } else if (key == "input_name") { + reader->Read(&input_name); + } else if (key == "global_output_index") { + // There should be only one global binding. + ICHECK(global_output_index < 0); + reader->Read(&global_output_index); + // When the key value is 'global_output_index', it means that this output is bound to + // a global interface. + global_binding = true; + } else { + LOG(FATAL) << "do not support key " << key; + } + } + // When this output is bound to a global interface, check if the global interface index + // start from 0. + if (global_binding) { + ICHECK(global_output_index >= 0); + } else { + // When this output is bound to a graph executor module interface, check if the module + // index start from 0. + ICHECK(mod_idx >= 0); + bindings[mod_idx] = input_name; + } + } + } +}; + +/*! + * \brief The binding information of all outputs of a module. + */ +struct OutputMap { + /*! \brief Output binding map, 'int' is output interface index.*/ + std::unordered_map output_binding_map; + OutputMap& operator=(const OutputMap& output) { + output_binding_map = output.output_binding_map; + return *this; + } + + /*!\brief This function is used to verify whether OutputMap is successfully loaded. + * \return Return true to indicate that this class has not been successfully loaded. + */ + bool Empty() { return output_binding_map.empty(); } + /*! \brief The pipeline outputs is the final outputs of pipeline, this function is used to + * get how many pipeline outputs are in this Outputmap + * \return Number of pipeline outputs. + */ + size_t GetGlobalOutputNum(void) const { + size_t num_output = 0; + for (auto bindings : output_binding_map) { + num_output += bindings.second.IsGlobalOutput() ? 1 : 0; + } + return num_output; + } + + /*! + * \brief Create a output binding map from JSONReader. + * \param reader Json reader. + */ + void Load(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + int output_idx = -1; + OutputBindings binding; + while (reader->NextObjectItem(&key)) { + if (key == "output_idx") { + reader->Read(&output_idx); + } else if (key == "dependencies") { + reader->Read(&binding); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(output_idx >= 0); + output_binding_map[output_idx] = binding; + } + } +}; +/*! + * \brief The binding or dependency information of each module output interface. + */ +struct PipelineConfig { + /*!\brief The key is the module index, this variable records all module pipeline configuration + * information. + */ + std::unordered_map config; + OutputMap& operator[](int key) { + ICHECK(config.find(key) != config.end()); + return config[key]; + } + + void Insert(int key, const OutputMap& map) { config[key] = map; } + + /*!\brief This function is used to verify whether config is loaded successfully. + * \return Return true to indicate that this class has not been successfully loaded. + */ + bool Empty() { return config.empty(); } + + /*! + * \brief Get the number of global outputs. + * \return The number of outputs the entire pipeline has. + */ + size_t GetGlobalOutputNum() const { + size_t num_output = 0; + for (auto mod_output : config) { + num_output += mod_output.second.GetGlobalOutputNum(); + } + return num_output; + } +}; +/*! + * \brief The information used to initialize the graph executor module, the information + * come from the export library function call. + */ +struct GraphModuleLoadInfo { + GraphModuleLoadInfo(const std::string& lib, const std::string& json, const std::string& params, + const std::string& device) + : lib_name(lib), json_name(json), params_name(params), dev(device) {} + GraphModuleLoadInfo() { ; } + std::string lib_name; + std::string json_name; + std::string params_name; + std::string dev; +}; +/*! The Module information of each module.The 'int' is module index. */ +using ModuleConfig = std::unordered_map; +#endif // TVM_RUNTIME_PIPELINE_PIPELINE_STRUCT_H_ diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index bd59be87f7d9..90d4ac64238f 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -25,10 +25,12 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -160,6 +162,51 @@ void Profiler::Stop() { } } +std::vector ToShape(NDArray shape_tensor) { + std::vector shape; + auto rank = shape_tensor.Shape().size(); + auto dtype = shape_tensor.DataType(); + + // For 0-rank shapes we need to allocate a single scalar. + if (rank == 0) { + return shape; + } + + // Otherwise we should be rank-1, and we will extract the number of dimensions + // for the output vector. + ICHECK_EQ(rank, 1U) << "shape tensor should be a k-length vector, found " << rank; + int64_t ndim = shape_tensor.Shape().at(0); + shape.resize(ndim); + + const DLTensor* dl_tensor = shape_tensor.operator->(); + if (dtype.is_int() && dtype.bits() == 32 && dtype.lanes() == 1) { + int32_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else if (dtype.is_int() && dtype.bits() == 64 && dtype.lanes() == 1) { + int64_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else { + LOG(FATAL) << "invalid shape tensor datatype: " << dtype; + } + + return shape; +} + +String ShapeString(NDArray shape, DLDataType dtype) { return ShapeString(ToShape(shape), dtype); } + +String ShapeString(const std::vector& shape, DLDataType dtype) { + std::stringstream sizes; + sizes << dtype << "["; + for (size_t i = 0; i < shape.size(); i++) { + if (i != 0) { + sizes << ", "; + } + sizes << shape[i]; + } + sizes << "]"; + return String(sizes.str()); +} + String ShapeString(const std::vector& shapes) { std::stringstream sizes; for (const NDArray& ary : shapes) { @@ -181,7 +228,7 @@ String ShapeString(const std::vector& shapes) { String ReportNode::AsCSV() const { // get unique headers - std::unordered_set unique_headers; + std::set unique_headers; for (auto row : calls) { for (auto p : row) { @@ -296,7 +343,7 @@ String ReportNode::AsJSON() const { return s.str(); } -String ReportNode::AsTable(bool sort, bool aggregate) const { +String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const { // aggregate calls by op hash (or op name if hash is not set) + argument shapes std::vector> aggregated_calls; if (aggregate) { @@ -311,6 +358,9 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { if (frame.find("Argument Shapes") != frame.end()) { name += Downcast(frame["Argument Shapes"]); } + if (frame.find("Device") != frame.end()) { + name += Downcast(frame["Device"]); + } if (aggregates.find(name) == aggregates.end()) { aggregates[name] = {i}; @@ -365,36 +415,38 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { } // compute columnwise sums - std::unordered_map col_sums; - for (auto call : aggregated_calls) { - for (auto p : call) { - if (p.second.as()) { - int64_t val = p.second.as()->value; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->value; - } - col_sums[p.first] = ObjectRef(make_object(val)); - } else if (p.second.as()) { - double val = p.second.as()->microseconds; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->microseconds; - } - col_sums[p.first] = ObjectRef(make_object(val)); - } else if (p.second.as()) { - double val = p.second.as()->percent; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->percent; + if (compute_col_sums) { + std::unordered_map col_sums; + for (auto call : aggregated_calls) { + for (auto p : call) { + if (p.second.as()) { + int64_t val = p.second.as()->value; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->value; + } + col_sums[p.first] = ObjectRef(make_object(val)); + } else if (p.second.as()) { + double val = p.second.as()->microseconds; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->microseconds; + } + col_sums[p.first] = ObjectRef(make_object(val)); + } else if (p.second.as()) { + double val = p.second.as()->percent; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->percent; + } + col_sums[p.first] = ObjectRef(make_object(val)); } - col_sums[p.first] = ObjectRef(make_object(val)); } } + col_sums["Name"] = String("Sum"); + aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator + aggregated_calls.push_back(col_sums); } - col_sums["Name"] = String("Sum"); - aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator - aggregated_calls.push_back(col_sums); // per-device metrics for (auto p : device_metrics) { @@ -404,18 +456,18 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { } // Table formatting - std::unordered_set unique_headers; - + std::set unique_headers; for (auto row : aggregated_calls) { for (auto p : row) { unique_headers.insert(p.first); } } - std::vector headers = {"Name", "Duration (us)", - "Percent"}; // always include these headers + // always include these headers in this order + std::vector headers = {"Name", "Duration (us)", "Percent", + "Device", "Count", "Argument Shapes"}; for (auto header : unique_headers) { - if (header != "Name" && header != "Duration (us)" && header != "Percent") { + if (std::find(headers.begin(), headers.end(), header) == headers.end()) { headers.push_back(header); } } @@ -616,6 +668,7 @@ TVM_REGISTER_OBJECT_TYPE(ReportNode); TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode); TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode); +TVM_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method(&ReportNode::AsTable); TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { return n->AsCSV(); }); TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { return n->AsJSON(); diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index d6575c35d10d..cd2d1332580b 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -25,6 +25,7 @@ #include "vm.h" #include +#include #include #include @@ -32,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -96,6 +98,58 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { } } +void VirtualMachineDebug::OpStartHook(Instruction instr) { + if (prof_ && prof_.operator*().IsRunning()) { + if (instr.op == Opcode::LoadConst) { + Device dev = GetDevice(exec_->const_device_type[instr.const_index]); + prof_.operator*().StartCall("VM::LoadConst", dev, {}); + } else if (instr.op == Opcode::DeviceCopy) { + Device dst_dev; + dst_dev.device_type = static_cast(instr.dst_device_type); + dst_dev.device_id = 0; + prof_.operator*().StartCall("VM::DeviceCopy", dst_dev, {}); + } else if (instr.op == Opcode::ReshapeTensor) { + prof_.operator*().StartCall("VM::ReshapeTensor", devices_[1], {}); + } else if (instr.op == Opcode::AllocTensor) { + auto shape = std::vector(instr.alloc_tensor.ndim); + + for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { + shape[i] = instr.alloc_tensor.shape[i]; + } + auto storage_obj = ReadRegister(instr.alloc_tensor.storage); + auto storage = Downcast(storage_obj); + prof_.operator*().StartCall( + "VM::AllocTensor", storage->buffer.device, + {{"Argument Shapes", profiling::ShapeString(shape, instr.alloc_tensor.dtype)}}); + } else if (instr.op == Opcode::AllocTensorReg) { + auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); + auto storage = Downcast(storage_obj); + Device cpu_dev = GetDevice(static_cast(kDLCPU)); + auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); + NDArray shape_tensor = Downcast(shape_obj).CopyTo(cpu_dev); + prof_.operator*().StartCall( + "VM::AllocTensorReg", storage->buffer.device, + {{"Argument Shapes", + profiling::ShapeString(shape_tensor, instr.alloc_tensor_reg.dtype)}}); + } else if (instr.op == Opcode::AllocStorage) { + auto size = LoadScalarInt(instr.alloc_storage.allocation_size); + std::ostringstream shape; + shape << DLDataType2String(instr.alloc_storage.dtype_hint) << "[" << size << "]"; + prof_.operator*().StartCall("VM::AllocStorage", + {static_cast(instr.alloc_storage.device_type), 0}, + {{"VM::Argument Shapes", String(shape.str())}}); + } else { + prof_.operator*().StartCall("VM::UnknownOp", devices_[1], {}); + } + } +} + +void VirtualMachineDebug::OpStopHook() { + if (prof_ && prof_.operator*().IsRunning()) { + prof_.operator*().StopCall(); + } +} + void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { ICHECK(exec_); diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 1efefda52b97..4325fa8a7999 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -51,6 +51,8 @@ class VirtualMachineDebug : public VirtualMachine { private: void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) final; + void OpStartHook(Instruction instr) final; + void OpStopHook() final; std::unordered_map packed_index_map_; dmlc::optional prof_; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c7a1baa1430d..addd5ca5d861 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -113,6 +113,9 @@ std::vector ToShape(NDArray shape_tensor) { return shape; } +void VirtualMachine::OpStartHook(Instruction instr) {} +void VirtualMachine::OpStopHook() {} + PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "invoke") { @@ -400,11 +403,9 @@ inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames_.back().register_file[r] = val; } -inline ObjectRef VirtualMachine::ReadRegister(Index r) const { - return frames_.back().register_file[r]; -} +ObjectRef VirtualMachine::ReadRegister(Index r) const { return frames_.back().register_file[r]; } -inline int64_t VirtualMachine::LoadScalarInt(Index r) const { +int64_t VirtualMachine::LoadScalarInt(Index r) const { int64_t result = 0; const auto& obj = ReadRegister(r); NDArray array = Downcast(CopyTo(obj, {kDLCPU, 0})); @@ -458,6 +459,11 @@ void VirtualMachine::RunLoop() { throw std::runtime_error("VM encountered fatal error"); } case Opcode::LoadConst: { + bool is_not_cached = const_pool_.size() <= static_cast(instr.const_index) || + !const_pool_[instr.const_index].defined(); + if (is_not_cached) { + OpStartHook(instr); + } auto constant_obj = exec_->constants[instr.const_index]; // We cache the allocated object in the constant pool. To measure, the // first iteration will set the pool up. The other iterations will @@ -471,6 +477,9 @@ void VirtualMachine::RunLoop() { const_pool_[instr.const_index] = CopyTo(constant_obj, dev); } WriteRegister(instr.dst, const_pool_[instr.const_index]); + if (is_not_cached) { + OpStopHook(); + } pc_++; goto main_loop; } @@ -560,6 +569,7 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::AllocTensor: { + OpStartHook(instr); auto shape = std::vector(instr.alloc_tensor.ndim); for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { @@ -572,10 +582,12 @@ void VirtualMachine::RunLoop() { auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype); WriteRegister(instr.dst, obj); + OpStopHook(); pc_++; goto main_loop; } case Opcode::AllocTensorReg: { + OpStartHook(instr); Device cpu_dev = GetDevice(static_cast(kDLCPU)); auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = Downcast(CopyTo(shape_obj, cpu_dev)); @@ -586,6 +598,7 @@ void VirtualMachine::RunLoop() { auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor_reg.dtype); WriteRegister(instr.dst, obj); + OpStopHook(); pc_++; goto main_loop; } @@ -609,6 +622,7 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::AllocStorage: { + OpStartHook(instr); auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = instr.alloc_storage.alignment; @@ -625,6 +639,7 @@ void VirtualMachine::RunLoop() { storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint); Storage storage(storage_obj); WriteRegister(instr.dst, storage); + OpStopHook(); pc_++; goto main_loop; } @@ -656,6 +671,7 @@ void VirtualMachine::RunLoop() { } } case Opcode::ReshapeTensor: { + OpStartHook(instr); Device cpu_dev = GetDevice(static_cast(kDLCPU)); auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor); NDArray tensor_arr = Downcast(tensor_obj); @@ -671,10 +687,12 @@ void VirtualMachine::RunLoop() { // Reshape the input tensor auto out_tensor = tensor_arr.CreateView(shape, tensor_arr->dtype); WriteRegister(instr.dst, out_tensor); + OpStopHook(); pc_++; goto main_loop; } case Opcode::DeviceCopy: { + OpStartHook(instr); auto tensor_src = ReadRegister(instr.src); NDArray src_data = Downcast(tensor_src); Device src_dev = src_data->device; @@ -686,6 +704,7 @@ void VirtualMachine::RunLoop() { NDArray dst_data = src_data.CopyTo(dst_dev); WriteRegister(instr.dst, dst_data); + OpStopHook(); pc_++; goto main_loop; } diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 5a4aa39f01b4..41221ad8a33e 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -47,6 +47,11 @@ runtime::Module Build(IRModule mod, Target target) { mod = tir::transform::SkipAssert()(mod); } + auto target_attr_map = tvm::TargetKind::GetAttrMap("TIRToRuntime"); + if (target_attr_map.count(target->kind)) { + return target_attr_map[target->kind](mod, target); + } + // the build function. std::string build_f_name = "target.build." + target->kind->name; const PackedFunc* bf = runtime::Registry::Get(build_f_name); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index bd22532d998c..5414366c1cd7 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -644,7 +644,7 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V } else { ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(t_void_p_, buf, index); + buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); } } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 12fbf2c3e42c..6c64f6798e47 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -77,6 +77,8 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, this->InitTarget(tm); } +void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } + void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { module_->setTargetTriple(tm->getTargetTriple().str()); module_->setDataLayout(tm->createDataLayout()); @@ -343,7 +345,26 @@ void CodeGenLLVM::Optimize() { // place optimization pass llvm::PassManagerBuilder builder; - builder.OptLevel = 3; + + // Use the same opt-level as specified in TargetMachine for running passes + llvm::CodeGenOpt::Level opt_level = target_machine_->getOptLevel(); + + switch (opt_level) { + case llvm::CodeGenOpt::Level::None: + builder.OptLevel = 0; + break; + case llvm::CodeGenOpt::Level::Less: + builder.OptLevel = 1; + break; + + case llvm::CodeGenOpt::Level::Default: + builder.OptLevel = 2; + break; + + default: + // CodeGenOpt::Level::Aggressive + builder.OptLevel = 3; + } #if TVM_LLVM_VERSION >= 50 builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false); @@ -410,7 +431,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { } else { return etype; } -} +} // namespace codegen llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { if (auto* ptr = type.as()) { @@ -626,6 +647,20 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { } llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { + // To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane + // LLVM vector types. + for (size_t i = 0, e = vecs.size(); i != e; ++i) { + llvm::Value* v = vecs[i]; + if (!v->getType()->isVectorTy()) { +#if TVM_LLVM_VERSION >= 110 + llvm::Type* vec_ty = llvm::FixedVectorType::get(v->getType(), 1); +#else + llvm::Type* vec_ty = llvm::VectorType::get(v->getType(), 1); +#endif + vecs[i] = builder_->CreateInsertElement(llvm::UndefValue::get(vec_ty), v, ConstInt32(0)); + } + } + // concat vector, tree shape reduction int total_lanes = 0; diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 177b53056354..4a9df65951c0 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -78,6 +78,13 @@ class CodeGenLLVM : public ExprFunctor, */ virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, bool target_c_runtime); + + /*! + * \brief Turn on fast math flags for floating point operations. + * \param fmf FastMathFlags to use for code generation. + */ + void SetFastMathFlag(llvm::FastMathFlags fmf); + /*! * \brief Compile and add function f to the current module. * \param f The function to be added. diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index be80a8bc767e..06b2be2d9fb6 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -106,6 +106,8 @@ void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::stri #if TVM_LLVM_VERSION < 50 opt.LessPreciseFPMADOption = true; #endif + // In clang, these are fed from LangOpts which describe language specific features + // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; @@ -139,8 +141,22 @@ std::unique_ptr GetLLVMTargetMachine(const Target& target, ICHECK(allow_null) << err << " target_triple=" << target_triple; return nullptr; } - llvm::TargetMachine* tm = - llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); + + Integer llvm_opt_level = target->GetAttr("opt-level").value_or(Integer(3)); + llvm::CodeGenOpt::Level llvm_opt; + if (llvm_opt_level <= 0) { + llvm_opt = llvm::CodeGenOpt::None; + } else if (llvm_opt_level == 1) { + llvm_opt = llvm::CodeGenOpt::Less; + } else if (llvm_opt_level == 2) { + llvm_opt = llvm::CodeGenOpt::Default; + } else { + // llvm_opt_level >= 3 + llvm_opt = llvm::CodeGenOpt::Aggressive; + } + + llvm::TargetMachine* tm = llvm_target->createTargetMachine( + target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_, llvm::CodeModel::Small, llvm_opt); return std::unique_ptr(tm); } diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index b967c7ad44e0..fcc44fb8f95c 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -72,7 +72,11 @@ #include #include #include +#if TVM_LLVM_VERSION >= 140 +#include +#else #include +#endif #include #include #include diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 0e4bca4396f5..86079b25aa90 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -258,8 +258,53 @@ class LLVMModuleNode final : public runtime::ModuleNode { // makes sense when we start to use multiple modules. cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); - cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + // See https://llvm.org/docs/LangRef.html#fast-math-flags for details + Bool fast_math_all = target->GetAttr("fast-math").value_or(Bool(false)); + Bool fast_math_nnan = target->GetAttr("fast-math-nnan").value_or(Bool(false)); + Bool fast_math_ninf = target->GetAttr("fast-math-ninf").value_or(Bool(false)); + Bool fast_math_nsz = target->GetAttr("fast-math-nsz").value_or(Bool(false)); + Bool fast_math_arcp = target->GetAttr("fast-math-arcp").value_or(Bool(false)); + + llvm::FastMathFlags fmf; + if (fast_math_all) { +#if TVM_LLVM_VERSION >= 60 + fmf.setFast(); +#else + fmf.setUnsafeAlgebra(); +#endif + } + + if (fast_math_nnan) { + fmf.setNoNaNs(); + } + if (fast_math_ninf) { + fmf.setNoInfs(); + } + if (fast_math_nsz) { + fmf.setNoSignedZeros(); + } + if (fast_math_arcp) { + fmf.setAllowReciprocal(); + } + +#if TVM_LLVM_VERSION >= 60 + Bool fast_math_contract = target->GetAttr("fast-math-contract").value_or(Bool(false)); + Bool fast_math_afn = target->GetAttr("fast-math-afn").value_or(Bool(false)); + Bool fast_math_reassoc = target->GetAttr("fast-math-reassoc").value_or(Bool(false)); + if (fast_math_contract) { + fmf.setAllowContract(true); + } + if (fast_math_afn) { + fmf.setApproxFunc(); + } + if (fast_math_reassoc) { + fmf.setAllowReassoc(); + } +#endif + cg->SetFastMathFlag(fmf); + + cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 10a437a547c1..4ff1c6ef61ed 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -35,7 +35,7 @@ namespace tvm { namespace codegen { -class CodeGenCHost final : public CodeGenC { +class CodeGenCHost : public CodeGenC { public: CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts, std::string target_str); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d719386d204b..4403af26d1a8 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -49,6 +49,14 @@ Array TargetKindRegEntry::ListTargetKinds() { return TargetKindRegistry::Global()->ListAllNames(); } +Map TargetKindRegEntry::ListTargetKindOptions(const TargetKind& target_kind) { + Map options; + for (const auto& kv : target_kind->key2vtype_) { + options.Set(kv.first, kv.second.type_key); + } + return options; +} + TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) { return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name); } @@ -222,6 +230,15 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") + // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("c", kDLCPU) @@ -359,5 +376,7 @@ TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("de /********** Registry **********/ TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds); +TVM_REGISTER_GLOBAL("target.ListTargetKindOptions") + .set_body_typed(TargetKindRegEntry::ListTargetKindOptions); } // namespace tvm diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index a47556bac101..657dc121961c 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -48,7 +48,7 @@ class ProducerToBufferTransformer : public StmtExprMutator { const std::unordered_map& tensor2buffers_; }; -/*! \brief Helper data structural to store informations. */ +/*! \brief Helper data structure to store information. */ struct CreateFuncInfo { /*! \brief The Tensor arg_list. */ Array arg_list; @@ -102,12 +102,6 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: f_push_block_vars(compute_op->axis); f_push_block_vars(compute_op->reduce_axis); - // If we have a rank 0 tensor then we manifest it as a rank 1 buffer with a single element. - if (compute_op->axis.size() == 0) { - iter_vars.push_back(IterVar(Range::FromMinExtent(0, 1), Var(), IterVarType::kDataPar)); - bindings.push_back(Var()); - } - // Step 2. Declare buffer and update op2buffers Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 447fc501d03b..9e2d3d0e725f 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -337,7 +337,8 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, } ICHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name << "'s declaration " - << " provided= " << lhs << ", intrin= " << rhs; + << " provided= " << lhs << ", intrin= " << rhs + << ", running this stage: " << stage; } } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index afc5c36ebb92..1d7c959d990d 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -90,6 +90,18 @@ Var Var::copy_with_suffix(const String& suffix) const { return Var(new_ptr); } +Var Var::copy_with_dtype(DataType dtype) const { + const VarNode* node = get(); + ObjectPtr new_ptr; + if (auto* ptr = this->as()) { + new_ptr = make_object(*ptr); + } else { + new_ptr = make_object(*node); + } + new_ptr->dtype = std::move(dtype); + return Var(new_ptr); +} + TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type, Span span) { if (type.IsObjectRef()) { @@ -904,6 +916,35 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // CommReducer CommReducer::CommReducer(Array lhs, Array rhs, Array result, Array identity_element, Span span) { + size_t n_group = result.size(); + CHECK_EQ(lhs.size(), n_group) << "ValueError: The number of vars in `lhs` must equal to the " + "number of elements in `results`"; + CHECK_EQ(rhs.size(), n_group) << "ValueError: The number of vars in `rhs` must equal to the " + "number of elements in `results`"; + CHECK_EQ(identity_element.size(), n_group) + << "ValueError: The number of identities must equal to the number of elements in `results`"; + + // Change the dtype of input vars to adapt to the dtype of identities + ArrayNode* p_lhs = lhs.CopyOnWrite(); + ArrayNode* p_rhs = rhs.CopyOnWrite(); + std::unordered_map var_map; + var_map.reserve(n_group * 2); + for (int i = 0; i < static_cast(n_group); ++i) { + DataType dtype = identity_element[i].dtype(); + Var l = lhs[i].copy_with_dtype(dtype); + Var r = rhs[i].copy_with_dtype(dtype); + var_map[lhs[i].get()] = l; + var_map[rhs[i].get()] = r; + + p_lhs->SetItem(i, l); + p_rhs->SetItem(i, r); + } + + ArrayNode* p_result = result.CopyOnWrite(); + for (int i = 0; i < static_cast(n_group); ++i) { + p_result->SetItem(i, Substitute(result[i], var_map)); + } + auto node = make_object(); node->lhs = lhs; node->rhs = rhs; diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index f265a8ae2b1b..7e3d3d107507 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -44,21 +44,11 @@ class ScriptCompleter : public StmtMutator { Map* buffer_var_map_; Stmt VisitStmt_(const BlockRealizeNode* op) override { contains_block = true; - Stmt body = StmtMutator::VisitStmt_(op); - if (!op->iter_values.empty() && !op->iter_values[0].dtype().is_int()) { - auto block_with_binding = CopyOnWrite(Downcast(body).get()); - std::vector bindings; - for (size_t i = 0; i < op->iter_values.size(); ++i) { - bindings.push_back(Var("i" + std::to_string(i))); - } - block_with_binding->iter_values = bindings; - body = BlockRealize(block_with_binding); - for (int i = op->iter_values.size() - 1; i >= 0; --i) { - body = For(Downcast(bindings[i]), op->block->iter_vars[i]->dom->min, - op->block->iter_vars[i]->dom->extent, {}, body); - } + for (const PrimExpr& value : op->iter_values) { + CHECK(value.dtype().is_int()) + << "BlockRealize iter_value expected a IntImm, but got " << value.dtype(); } - return body; + return StmtMutator::VisitStmt_(op); } Stmt VisitStmt_(const BlockNode* op) override { @@ -122,7 +112,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { // generate surrounding loops automatically Stmt res = script_completer(func->body); // generate root block automatically - if (script_completer.contains_block && !contain_root) { + if ((script_completer.contains_block || root_allocates.size()) && !contain_root) { res = Block({}, {}, {}, "root", res, NullOpt, root_allocates); res = BlockRealize({}, Bool(true), Downcast(res)); } diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index d8dcf57b91e4..eb72773ffedb 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -24,29 +24,37 @@ namespace tir { String ScheduleError::RenderReport(const String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; - os << "ScheduleError: An error occurred in the schedule primitive '" << primitive - << "'.\n\nThe IR is:\n" - << AsTVMScript(mod); + + // get locations of interest Array locs = LocationsOfInterest(); + std::unordered_map loc_obj_to_name; int n_locs = locs.size(); - std::vector roi_names; - roi_names.reserve(n_locs); - if (n_locs > 0) { - os << "Regions of interest:\n"; - for (const ObjectRef& obj : locs) { - String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); - os << name << "\n" << obj; - roi_names.emplace_back(std::move(name)); - } - os << "\n"; - } std::string msg = DetailRenderTemplate(); - for (int i = 0; i < n_locs; ++i) { - std::string src = "{" + std::to_string(i) + "}"; - for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { - msg.replace(pos, src.length(), roi_names[i]); + if (n_locs > 0) { + for (int i = 0; i < n_locs; ++i) { + std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(i); + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), name); + } + loc_obj_to_name.emplace(locs[i], std::move(name)); } } + + // print IR module + runtime::TypedPackedFunc annotate = + runtime::TypedPackedFunc( + [&loc_obj_to_name](const Stmt& expr) -> std::string { + auto it = loc_obj_to_name.find(Downcast(expr)); + if (it == loc_obj_to_name.end()) return ""; + return it->second; + }); + + os << "ScheduleError: An error occurred in the schedule primitive '" << primitive + << "'.\n\nThe IR with diagnostic is:\n" + << AsTVMScriptWithDiagnostic(mod, "tir", false, annotate); + + // print error message os << "Error message: " << msg; return os.str(); } diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 008d47792f69..55869e12b6b2 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -121,6 +121,11 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind runtime::ThreadScope thread_scope) { PreOrderVisit(loop, [&](const ObjectRef& node) { if (const auto* realize = node.as()) { + // If this block doesn't have corresponding StmtSRef in the schedule state, it must be a block + // inside `tir.init()`. We don't check the condition for such blocks. + if (!self->stmt2ref.count(realize->block.get())) { + return false; + } CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef(realize), thread_scope); } diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc new file mode 100644 index 000000000000..d01ae8a45113 --- /dev/null +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/transforms/convert_for_loops_serial.cc + * \brief Convert all for loops to serial for lesser memory consumption + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class ForLoopSerialConverter : public StmtExprMutator { + public: + ForLoopSerialConverter() = default; + Stmt operator()(const PrimFunc& func); + + private: + Stmt VisitStmt_(const ForNode* op) override; +}; + +Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) { + if (op->kind == ForKind::kParallel) { + return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding, + op->annotations, op->span); + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt ForLoopSerialConverter::operator()(const PrimFunc& func) { + return this->VisitStmt(func->body); +} + +PrimFunc ConvertForLoopsToSerial(PrimFunc func) { + PrimFuncNode* fptr = func.CopyOnWrite(); + fptr->body = ForLoopSerialConverter()(func); + return func; +} + +namespace transform { + +Pass ConvertForLoopsToSerial() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ConvertForLoopsToSerial(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") + .set_body_typed(ConvertForLoopsToSerial); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 6a3ce596c2fe..ccc660509ca1 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -60,6 +60,19 @@ using runtime::ThreadScope; */ class BufferShapeLegalize : public StmtExprMutator { public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferShapeLegalize", {}); + } + explicit BufferShapeLegalize(const Map& extern_buffer_map, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) { @@ -383,6 +396,19 @@ class BufferShapeLegalize : public StmtExprMutator { */ class BufferStrideLegalize : public StmtExprMutator { public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferStrideLegalize", {}); + } + explicit BufferStrideLegalize(const Map& extern_buffer_map, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) { @@ -565,6 +591,15 @@ class BufferStrideLegalize : public StmtExprMutator { */ class ThreadScopePropagate : public StmtExprMutator { public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + auto fptr = func.CopyOnWrite(); + fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.ThreadScopePropagate", {}); + } + explicit ThreadScopePropagate(const Map& extern_buffer_map) { // External buffers shouldn't be overwritten, even if they have a // BufferRealizeNode. @@ -718,6 +753,19 @@ class ThreadScopePropagate : public StmtExprMutator { */ class BufferBindUnwrapper : public StmtExprMutator { public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferBindUnwrapper", {}); + } + explicit BufferBindUnwrapper(const Map& extern_buffer_map, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) { @@ -1030,6 +1078,20 @@ class BufferBindUnwrapper : public StmtExprMutator { class StorageFlattener : public StmtExprMutator { public: + static transform::Pass Pass(int cache_line_size, bool create_bound_attributes) { + auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, + &bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); + } + explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { @@ -1355,6 +1417,19 @@ class StorageFlattener : public StmtExprMutator { */ class AssertSimplifier : public StmtMutator { public: + static transform::Pass Pass() { + auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.AssertSimplifier", {}); + } + explicit AssertSimplifier(IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) {} @@ -1409,30 +1484,25 @@ class AssertSimplifier : public StmtMutator { // We do support a few relaxed case, such as binding a // region with shape [1, 1, n, m] to buffer with shape [n, m] PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { - // Only apply this pass to TIR from TE schedules + // Only apply this pass to TIR from TE schedules. Because this is a + // per-function attribute, we can't just check it once for the + // entire module and apply the Sequential transform. Optional from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false)); if (from_legacy_te_schedule.value()) { - auto fptr = func.CopyOnWrite(); - - IRVisitorWithAnalyzer bound_analyzer; - bound_analyzer(fptr->body); - - fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); - - auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer); - fptr->body = stride_legalize(std::move(fptr->body)); - fptr->buffer_map = stride_legalize.UpdatedExternBufferMap(); - - fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body)); - - fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); - - fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, - &bound_analyzer)(std::move(fptr->body)); - - fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body)); - - return func; + auto seq = transform::Sequential( + { + BufferShapeLegalize::Pass(), + BufferStrideLegalize::Pass(), + ThreadScopePropagate::Pass(), + BufferBindUnwrapper::Pass(), + StorageFlattener::Pass(cache_line_size, create_bound_attributes), + AssertSimplifier::Pass(), + }, + "tir.StorageFlatten_impl"); + GlobalVar dummy_func_name("dummy_func"); + IRModule mod(Map({{dummy_func_name, func}})); + mod = seq(mod); + return Downcast(mod->Lookup(dummy_func_name)); } else { return func; } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 2e8ba11c0262..6106eb2225e1 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -152,8 +152,18 @@ TEST(TargetCreation, DeduplicateKeys) { ICHECK_EQ(target->GetAttr("link-params"), false); } -TEST(TargetKindRegistryListTargetKinds, Basic) { +TEST(TargetKindRegistry, ListTargetKinds) { Array names = TargetKindRegEntry::ListTargetKinds(); ICHECK_EQ(names.empty(), false); ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1); } + +TEST(TargetKindRegistry, ListTargetOptions) { + TargetKind llvm = TargetKind::Get("llvm").value(); + Map attrs = TargetKindRegEntry::ListTargetKindOptions(llvm); + ICHECK_EQ(attrs.empty(), false); + + ICHECK_EQ(attrs["mattr"], "Array"); + ICHECK_EQ(attrs["mcpu"], "runtime.String"); + ICHECK_EQ(attrs["system-lib"], "IntImm"); +} diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index bb9c69bf4a0e..73361774821b 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -17,8 +17,9 @@ import datetime import pathlib - +import json import pytest + import tvm.target.target from tvm.micro import project from tvm import micro, relay @@ -34,19 +35,16 @@ / "template_project" ).resolve() +BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" + def arduino_boards() -> dict: """Returns a dict mapping board to target model""" - template = project.TemplateProject.from_directory(TEMPLATE_PROJECT_DIR) - project_options = template.info()["project_options"] - for option in project_options: - if option["name"] == "arduino_board": - boards = option["choices"] - if option["name"] == "arduino_model": - models = option["choices"] - - arduino_boards = {boards[i]: models[i] for i in range(len(boards))} - return arduino_boards + with open(BOARDS) as f: + board_properties = json.load(f) + + boards_model = {board: info["model"] for board, info in board_properties.items()} + return boards_model ARDUINO_BOARDS = arduino_boards() diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index be1f231156ad..089598007651 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -374,6 +374,9 @@ def test_tensors(sess): @tvm.testing.requires_micro def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): """Test AutoTune for microTVM Zephyr""" + if board in ["qemu_riscv32", "qemu_riscv64"]: + pytest.xfail(f"Autotune fails on {board}.") + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index 5bc665b748f6..f79aa8bd70d2 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -47,6 +47,8 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): "nrf5340dk_nrf5340_cpuapp", "nucleo_l4r5zi", "qemu_cortex_r5", + "qemu_riscv32", + "qemu_riscv64", ]: pytest.skip(msg="Model does not fit.") @@ -55,8 +57,8 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): output_shape = (1, 10) build_config = {"debug": tvm_debug} - model_url = "https://github.com/eembc/ulpmark-ml/raw/fc1499c7cc83681a02820d5ddf5d97fe75d4f663/base_models/ic01/ic01_fp32.tflite" - model_path = download_testdata(model_url, "ic01_fp32.tflite", module="model") + model_url = "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/model/image_classification_fp32.tflite" + model_path = download_testdata(model_url, "image_classification_fp32.tflite", module="model") # Import TFLite model tflite_model_buf = open(model_path, "rb").read() diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index f151a85ec5b1..e582874d1de2 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -184,7 +184,7 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti ), "Got {} Arm Compute Library partitions, expected {}".format( partition_count, acl_partitions ) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target=target, params=params) diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py index 9deaa758639e..b174f9a78866 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_pooling.py +++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py @@ -123,6 +123,7 @@ def _get_expected_pooling_codegen( "num_inputs": "1", "num_outputs": "1", "layout": [["NHWC"]], + "out_layout": [[""]], "shape": [[list(output_shape)]], "dtype": [[dtype]], "padding": [[str(p) for p in padding]], @@ -149,6 +150,7 @@ def _get_expected_global_pooling_codegen(shape, dtype, typef): "num_inputs": "1", "num_outputs": "1", "layout": [["NHWC"]], + "out_layout": [[""]], "shape": [[[1, 1, 1, shape[3]]]], "dtype": [[dtype]], }, diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py index 46bd049402a9..5a12b0487408 100644 --- a/tests/python/contrib/test_bnns/infrastructure.py +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -142,7 +142,7 @@ def build_module(mod, target, params=None, enable_bnns=True, tvm_ops=0): with tvm.transform.PassContext(opt_level=3): if enable_bnns: mod = partition_for_bnns(mod) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target=target, target_host=target, params=params) diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index 92e8f11a2312..c5ebde4b9c61 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -149,7 +149,7 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1): npu_partitions : int, optional The number of Ethos-N partitions expected. """ - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() with tvm.transform.PassContext( opt_level=3, config={"relay.ext.ethos-n.options": {"variant": get_ethosn_variant()}} ): @@ -262,7 +262,7 @@ def test_error(mod, params, err_msg): except tvm.error.TVMError as e: caught = e.args[0] finally: - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() assert caught is not None assert err_msg in caught, caught diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 8b0d3063a696..01a7ceb9ed56 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -29,7 +29,9 @@ import os import struct import numpy +import math from enum import IntEnum +import tensorflow as tf from ethosu.vela.register_command_stream_generator import CmdMode from ethosu.vela.register_command_stream_generator import cmd0 @@ -66,26 +68,6 @@ def __init__(self): self.npu_ops = set() -def parse_relay_tflite_model(tflite_model, input_tensor, input_shape, input_dtype): - mod_, params_ = relay.frontend.from_tflite( - tflite_model, - shape_dict={input_tensor: input_shape}, - dtype_dict={input_tensor: input_dtype}, - ) - return mod_, params_ - - -def parse_tflite_model(model_file): - try: - import tflite - - return tflite.Model.GetRootAsModel(model_file, 0) - except AttributeError: - import tflite.Model - - return tflite.Model.Model.GetRootAsModel(model_file, 0) - - def print_payload(payload): cmds = deserialize_command_stream(payload) for cmd_val in cmds: @@ -270,6 +252,58 @@ def flatten_numpy_data(data): return reshaped_data +class InputGenerator: + def __init__(self, random_state): + self._random_state = random_state + + def generate(self, size, dtype): + if dtype == numpy.float32: + print("random float32") + return self._random_state.uniform(-1, 1, size).astype(dtype) + else: + print("random (u)int min=%d max=%d", numpy.iinfo(dtype).min, numpy.iinfo(dtype).max) + low = numpy.iinfo(dtype).min + high = numpy.iinfo(dtype).max + 1 + return self._random_state.randint(low, high, size, dtype) + + +def generate_ref_data_tflite(model): + """ + This method generates reference data by running the specified model on tflite with random input data. + The random input data and generated output data are returned. + """ + expected_output_data = {} + interpreter = tf.lite.Interpreter(model_content=model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Initialize random generators with a fixed seed to get deterministic results + seed = 0 + random_state = numpy.random.RandomState(seed) + + inputgen = InputGenerator(random_state) + + # Generate input data + input_data = { + input_detail["name"]: inputgen.generate( + input_detail["shape"], + input_detail["dtype"], + ) + for input_detail in input_details + } + for index, value in enumerate(input_data.values()): + interpreter.set_tensor(index, value) + interpreter.invoke() + + expected_output_data = [ + interpreter.get_tensor(output_detail["index"]) for output_detail in output_details + ] + + return input_data, expected_output_data + + def generate_weights_data(shape, dtype): size = 1 for dim in shape: @@ -278,7 +312,7 @@ def generate_weights_data(shape, dtype): def get_convolutional_args(call, include_buffers=False, remove_constants=False): - """A method to extract the arguments from conv2d or depthwise2d extern call.""" + """A method to extract the arguments from conv2d or depthwise_conv2d extern call.""" args = call.args conv_args = [] remove_indices = [0] @@ -299,6 +333,44 @@ def get_convolutional_args(call, include_buffers=False, remove_constants=False): return conv_args +def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]): + assert len(strides) == 2 + assert len(dilation) == 2 + assert len(kernel_shape) == 2 + if padding.lower() == "valid": + h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0]) + w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1]) + if padding.lower() == "same": + h = math.ceil(ifm_shape[1] / strides[0]) + w = math.ceil(ifm_shape[2] / strides[1]) + ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]] + return ofm_shape + + +def compute_padding_shape(ifm_shape, ofm_shape, padding, kernel_shape, strides, dilation=[1, 1]): + assert len(strides) == 2 + assert len(dilation) == 2 + assert len(kernel_shape) == 2 + if padding.lower() == "valid": + return [0, 0, 0, 0] + if padding.lower() == "same": + effective_kernel_shape = [ + dilation[0] * (kernel_shape[0] - 1) + 1, + dilation[1] * (kernel_shape[1] - 1) + 1, + ] + pad_along_height = max( + (ofm_shape[1] - 1) * strides[0] + effective_kernel_shape[0] - ifm_shape[1], 0 + ) + pad_along_width = max( + (ofm_shape[2] - 1) * strides[1] + effective_kernel_shape[1] - ifm_shape[2], 0 + ) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + return [pad_top, pad_left, pad_bottom, pad_right] + + def make_ethosu_conv2d( ifm, ifm_channels, @@ -343,3 +415,48 @@ def make_ethosu_conv2d( ofm_layout=ofm_layout, ) return conv + + +def make_ethosu_depthwise_conv2d( + ifm, + channels, + kernel_shape, + padding, + strides, + dilation, + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", + weight_dtype="int8", +): + # params + weight_shape = (channels, kernel_shape[0], kernel_shape[1], 1) + padding = get_pad_tuple(padding, kernel_shape) + + scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8") + scale_bias = relay.const(scale_bias_data, dtype="uint8") + weight_data = generate_weights_data(weight_shape, weight_dtype) + weight = relay.const(weight_data, dtype=weight_dtype) + depthwise = ethosu_ops.ethosu_depthwise_conv2d( + ifm, + weight, + scale_bias, + lut=relay.const([], dtype="int8"), + ifm_scale=0.6, + ifm_zero_point=11, + weight_zero_point=13, + ofm_scale=0.26, + ofm_zero_point=15, + kernel_shape=kernel_shape, + ofm_channels=channels, + strides=strides, + padding=padding, + dilation=dilation, + activation=activation, + clip_min=15 if activation == "CLIP" else 0, + clip_max=105 if activation == "CLIP" else 0, + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + return depthwise diff --git a/tests/python/contrib/test_ethosu/test_attr_passing.py b/tests/python/contrib/test_ethosu/test_attr_passing.py index a2fbe1888d2a..6b99a5c1e540 100644 --- a/tests/python/contrib/test_ethosu/test_attr_passing.py +++ b/tests/python/contrib/test_ethosu/test_attr_passing.py @@ -28,7 +28,9 @@ def test_compiler_attr(): } with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethosu.options": config}): with tvm.target.Target("c -device=micro_dev"): - assert util.get_accelerator_config() == config["accelerator_config"] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str == config["accelerator_config"] def test_compiler_attr_default(): @@ -37,7 +39,9 @@ def test_compiler_attr_default(): } with tvm.transform.PassContext(opt_level=3): with tvm.target.Target("c -device=micro_dev"): - assert util.get_accelerator_config() == default_config["accelerator_config"] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str == default_config["accelerator_config"] if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 1944de5f94c0..4949d6814ab2 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -18,14 +18,12 @@ import pytest pytest.importorskip("ethosu.vela") -import os import numpy as np -import pathlib +import tflite.Model import tvm -import tvm.micro as micro +import tensorflow as tf from tvm import relay -from tvm.relay.backend.contrib import ethosu from tvm.relay.backend.contrib.ethosu import util from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tests.python.relay.aot.aot_test_utils import generate_ref_data @@ -168,5 +166,93 @@ def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_ infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)]) +@pytest.mark.parametrize( + "kernel_shape, activation", + [((3, 3), "relu"), ((1, 2), None)], +) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))]) +def test_tflite_depthwise_conv2d( + accel_type, + ifm_shape, + kernel_shape, + padding, + strides, + dilation, + activation, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def depthwise_conv2d(self, x): + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + # The input strides to the TensorFlow API needs to be of shape 1x4 + tf_strides = [1, strides[0], strides[1], 1] + op = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=padding, dilations=dilation + ) + if activation: + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.depthwise_conv2d.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 60ed352edcfd..5b60102162be 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -64,10 +64,10 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, def test_weight_stream_only(): - def _planner(te_graph, const_dict, sch): - weights = te_graph.inputs[1] - bias = te_graph.inputs[2] - out = te_graph.outputs[0] + def _planner(cached_func, const_dict, sch): + weights = cached_func.inputs[1] + bias = cached_func.inputs[2] + out = cached_func.outputs[0] conv_compute = Convolution2DCompute.from_output(out) co = conv_compute.split(sch, 3, 2) cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d]) @@ -208,10 +208,10 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle def test_mixed_read(): - def _planner(te_graph, const_dict, sch): - weight = te_graph.inputs[4] - scale_bias = te_graph.inputs[5] - out = te_graph.outputs[0] + def _planner(cached_func, const_dict, sch): + weight = cached_func.inputs[4] + scale_bias = cached_func.inputs[5] + out = cached_func.outputs[0] conv_compute = Convolution2DCompute.from_output(out) co = conv_compute.split(sch, 3, 2) cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 911a0e6eefc6..b9a588d4aec0 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -20,15 +20,33 @@ pytest.importorskip("ethosu.vela") import numpy as np +import tensorflow as tf +import tflite.Model import tvm from tvm import relay -from tvm.relay.backend.contrib import ethosu from tvm.relay.backend.contrib.ethosu import legalize, preprocess -from tvm.relay.dataflow_pattern import * -from tvm.relay.op.contrib.ethosu import * +from tvm.relay import dataflow_pattern +from tvm.relay.op.contrib import ethosu +from tvm.relay.build_module import bind_params_by_name from . import relay_ir_builder +from . import infra + + +def partition_ethosu_by_table(mod, pattern_table): + """In case only the legalization part is supported for an operator, we don't + want to add the operator's pattern to the pattern table so that the compiler + wouldn't attempt to offload an operator without full stack support.""" + mod = relay.transform.InferType()(mod) + mod = relay.transform.MergeComposite(pattern_table)(mod) + mod = relay.transform.AnnotateTarget("ethosu")(mod) + mod = relay.transform.MergeCompilerRegions()(mod) + mod = relay.transform.InferType()(mod) + mod = relay.transform.PartitionGraph()(mod) + mod = relay.transform.InferType()(mod) + mod = preprocess.preprocess_ext_io()(mod) + return mod def test_split_indices_legalize(): @@ -294,7 +312,7 @@ def verify_linear(ext_func, conv2d_params): ] for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) - mod = partition_for_ethosu(mod) + mod = ethosu.partition_for_ethosu(mod) mod = legalize.LegalizeEthosUConv2D()(mod) verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params) @@ -327,12 +345,123 @@ def create_graph_single_unsupported_ifm_layout( for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) - mod = partition_for_ethosu(mod) + mod = ethosu.partition_for_ethosu(mod) with pytest.raises( tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW" ): mod = legalize.LegalizeEthosUConv2D()(mod) +@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)]) +@pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)]) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) +@pytest.mark.parametrize("activation", ["RELU", None]) +def test_tflite_depthwise_conv_2d_legalize( + ifm_shape, kernel_shape, padding, strides, dilation, activation +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def depthwise_conv2d(self, x): + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + # The input strides to the TensorFlow API needs to be of shape 1x4 + tf_strides = [1, strides[0], strides[1], 1] + op = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=padding, dilations=dilation + ) + if activation: + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.depthwise_conv2d.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + op = ext_func.body + ofm_channels = op.attrs.ofm_channels + + # check IFM + ifm = op.args[0].checked_type + assert list(ifm.shape) == list(ifm_shape) + assert str(ifm.dtype) == dtype + assert ifm.shape[3] == ofm_channels + + # check OFM + ofm = op.checked_type + expected_ofm_shape = infra.compute_ofm_shape( + ifm_shape, padding, kernel_shape, strides, dilation + ) + assert list(ofm.shape) == list(expected_ofm_shape) + assert str(ofm.dtype) == dtype + assert ofm.shape[3] == ofm_channels + + # check weights + weights_ohwi = op.args[1].data.asnumpy() + assert str(weights_ohwi.dtype) == dtype + assert weights_ohwi.shape[0] == ofm_channels + assert weights_ohwi.shape[1] == kernel_shape[0] + assert weights_ohwi.shape[2] == kernel_shape[1] + assert weights_ohwi.shape[3] == 1 # only depth multiplier 1 is supported + + # Check that scale_bias matches weight tensor + assert list(op.args[2].checked_type.shape)[0] == ofm_channels + + expected_padding = infra.compute_padding_shape( + ifm_shape, expected_ofm_shape, padding, kernel_shape, strides, dilation + ) + assert list(op.attrs.padding) == list(expected_padding) + assert op.attrs.ofm_channels == ofm_channels + assert list(op.attrs.strides) == list(strides) + assert list(op.attrs.dilation) == list(dilation) + if activation == "RELU": + assert str(op.attrs.activation) == "CLIP" + + depthwise_pattern_table = [ + ( + ethosu.QnnDepthwiseConv2DParams.composite_name, + ethosu.qnn_depthwise_conv2d_pattern(), + lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(), + ) + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + + mod["main"] = bind_params_by_name(mod["main"], params) + mod = partition_ethosu_by_table(mod, depthwise_pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + legalize.EthosuDepthwiseConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py new file mode 100644 index 000000000000..b3ce74c4e84a --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +pytest.importorskip("ethosu.vela") + +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from .infra import make_ethosu_depthwise_conv2d, get_convolutional_args + + +@pytest.mark.parametrize( + "trial", + [ + [(1, 8, 8, 3), 3, (3, 2), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC"], + [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC"], + [(1, 8, 8, 3), 3, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC"], + [(1, 1, 1, 1), 1, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC"], + [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC"], + [(1, 8, 2, 8, 16), 18, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHCWB16", "NHWC"], + [(1, 7, 9, 40), 40, (3, 2), (1, 2), (2, 1), (1, 2), "CLIP", "NHWC", "NHCWB16"], + [(1, 4, 12, 9, 16), 182, (2, 3), (6, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], + [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHWC"], + [(1, 7, 9, 41), 41, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHCWB16"], + [ + (1, 13, 12, 19, 16), + 182, + (1, 3), + (5, 3), + (2, 1), + (2, 1), + "CLIP", + "NHCWB16", + "NHCWB16", + ], + ], +) +def test_depthwise_conv2d_single(trial): + def _get_func( + ifm_shape, + channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + depthwise = make_ethosu_depthwise_conv2d( + ifm, + channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(depthwise), depthwise) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func(*trial) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_convolutional_args(stmt, remove_constants=True)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + ( + ifm_shape, + channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ) = trial + dilated_kernel_h = (kernel_shape[0] - 1) * dilation[0] + 1 + dilated_kernel_w = (kernel_shape[1] - 1) * dilation[1] + 1 + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] + ifm_stride_h = ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[2] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1 + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[3] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1 + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = channels if ofm_width > 1 else 1 + ofm_stride_h = channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((channels - 1) // 16 + 1) + + answer = [ + "int8", + ifm_shape[1], + ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + channels, + ifm_shape[1], + 0, + ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + 0, + 0, + 0, + 0, + 0.6, + 11, + ifm_layout, + ifm_stride_h, + ifm_stride_w, + ifm_stride_c, + "int8", + ofm_height, + ofm_width, + channels, + ofm_height, + 0, + ofm_width, + 0, + 0, + 0, + 0, + 0.26, + 15, + ofm_layout, + ofm_stride_h, + ofm_stride_w, + ofm_stride_c, + kernel_shape[1], + kernel_shape[0], + strides[1], + strides[0], + dilation[1], + dilation[0], + 13, + padding[0], + padding[1], + padding[0], + padding[1], + activation, + 15 if activation == "CLIP" else 0, + 105 if activation == "CLIP" else 0, + "NONE", + ] + assert data[0] == answer, data[0] diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 8077271ed496..b04059011e8e 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -81,10 +81,10 @@ def test_inline_no_ops(): func = relay.Function(relay.analysis.free_vars(relu2), relu2) func = run_opt_pass(func, relay.transform.InferType()) - te_graph = lower_to_te(func) - sch = te.create_schedule([te_graph.outputs[0].op]) - inline_no_ops(te_graph, sch) - reshape_tensor = te_graph.outputs[0].op.input_tensors[0] + cached_func = lower_to_te(func) + sch = te.create_schedule([cached_func.outputs[0].op]) + inline_no_ops(cached_func, sch) + reshape_tensor = cached_func.outputs[0].op.input_tensors[0] slice_tensor = reshape_tensor.op.input_tensors[0].op.input_tensors[0] assert sch[reshape_tensor].attach_type == AttachType.kInline assert sch[slice_tensor].attach_type == AttachType.kInline @@ -114,11 +114,11 @@ def test_copy_constants(): func = run_opt_pass(func, relay.transform.InferType()) func, const_dict = extract_constants(func) - te_graph = lower_to_te(func) + cached_func = lower_to_te(func) - sch = te.create_schedule([te_graph.outputs[0].op]) + sch = te.create_schedule([cached_func.outputs[0].op]) planner = copy_constants() - planner(te_graph, const_dict, sch) + planner(cached_func, const_dict, sch) assert len(sch.stages) == 21 assert ".global" in sch.stages[5].op.name assert ".global" in sch.stages[7].op.name diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index b07f3a5016fa..8240b392a1cf 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -497,6 +497,81 @@ def populate_ethosu_conv2d_calls(stmt): assert w_zero_point == ref["w_zero_point"] +# fmt: off +"""A ethosu_depthwise_conv2d tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuDepthwiseConv2D: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_depthwise_conv2d: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_1, [3, 3, 2, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_2, [3, 10], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [1, 6, 7, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, T.load("int8", ethosu_depthwise_conv2d_1.data, 0), 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, T.load("int8", placeholder_4.data, 0), 18, 13, T.load("uint8", placeholder_5.data, 0), 30, 0, 0, 0, 0, "CLIP", 15, 105, "NONE", dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +def test_translate_ethosu_depthwise_conv2d(): + def extract_ethosu_depthwise_conv2d_extern_call(mod): + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_depthwise_conv2d_calls = list() + + def populate_ethosu_depthwise_conv2d_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_depthwise_conv2d" + ): + ethosu_depthwise_conv2d_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_depthwise_conv2d_calls) + return ethosu_depthwise_conv2d_calls[0] + + depthwise_conv2d_call = extract_ethosu_depthwise_conv2d_extern_call(SingleEthosuDepthwiseConv2D) + npu_op, w_zero_point = tir_to_cs_translator.translate_ethosu_depthwise_conv2d( + depthwise_conv2d_call + ) + + assert npu_op.ifm.data_type == vapi.NpuDataType.INT8 + assert npu_op.ifm.shape == vapi.NpuShape3D(8, 8, 3) + assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).height_0 + assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).height_1 + assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).width_0 + assert npu_op.ifm.quantization == pytest.approx(vapi.NpuQuantization(0.6, 11)) + assert npu_op.ifm.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm.strides == vapi.NpuShape3D(24, 3, 1) + # Compare OFM + assert npu_op.ofm.data_type == vapi.NpuDataType.INT8 + assert npu_op.ofm.shape == vapi.NpuShape3D(6, 7, 3) + assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(6, 0, 8, [0, 0, 0, 0]).height_0 + assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(6, 0, 7, [0, 0, 0, 0]).height_1 + assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(6, 0, 7, [0, 0, 0, 0]).width_0 + assert npu_op.ofm.quantization == pytest.approx(vapi.NpuQuantization(0.26, 15)) + assert npu_op.ofm.layout == vapi.NpuLayout.NHWC + assert npu_op.ofm.strides == vapi.NpuShape3D(21, 3, 1) + # Compare kernel and padding + assert ( + npu_op.kernel.__dict__ + == vapi.NpuKernel(w=2, h=3, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1).__dict__ + ) + assert npu_op.padding == vapi.NpuPadding(top=0, left=0, bottom=0, right=0) + # Compare activation + assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + assert npu_op.activation.min == 0 + assert npu_op.activation.max == pytest.approx(23.4) + # Compare ifm upscaling + assert npu_op.ifm_upscale == vapi.NpuResamplingMode.NONE + # Compare weight quantization parameters + assert w_zero_point == 13 + + def test_translate_ethosu_copy(): def extract_ethosu_copy_extern_calls(mod): """This function will obtain all ethosu_conv2d diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py new file mode 100644 index 000000000000..47fddad773b2 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_type_inference.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +from tvm import relay +from tvm.relay.testing import run_opt_pass +from .infra import make_ethosu_conv2d +from .infra import make_ethosu_depthwise_conv2d + + +@pytest.mark.parametrize( + ["ifm_shape", "ifm_layout"], [((1, 56, 72, 55), "NHWC"), ((1, 56, 4, 72, 16), "NHCWB16")] +) +@pytest.mark.parametrize( + "ofm_shape,ofm_layout", [((1, 54, 38, 122), "NHWC"), ((1, 54, 8, 38, 16), "NHCWB16")] +) +def test_ethosu_conv2d_type_inference( + ifm_shape, + ifm_layout, + ofm_shape, + ofm_layout, +): + ifm_channels = 55 + ofm_channels = 122 + kernel_shape = (3, 2) + padding = (0, 1, 2, 3) + strides = (1, 2) + dilation = (2, 1) + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + conv2d = make_ethosu_conv2d( + ifm, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + f = relay.Function([ifm], conv2d) + f = run_opt_pass(f, relay.transform.InferType()) + assert tuple(f.body.checked_type.shape) == ofm_shape + + +@pytest.mark.parametrize( + "ifm_shape, ifm_layout", [((1, 46, 71, 55), "NHWC"), ((1, 46, 4, 71, 16), "NHCWB16")] +) +@pytest.mark.parametrize( + "ofm_shape, ofm_layout", [((1, 44, 37, 55), "NHWC"), ((1, 44, 4, 37, 16), "NHCWB16")] +) +def test_ethosu_depthwise_conv2d_type_inference( + ifm_shape, + ifm_layout, + ofm_shape, + ofm_layout, +): + channels = 55 + kernel_shape = (3, 2) + padding = (0, 1, 2, 3) + strides = (1, 2) + dilation = (2, 1) + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + depthwise_conv2d = make_ethosu_depthwise_conv2d( + ifm, + channels, + kernel_shape, + padding, + strides, + dilation, + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + f = relay.Function([ifm], depthwise_conv2d) + f = run_opt_pass(f, relay.transform.InferType()) + assert tuple(f.body.checked_type.shape) == ofm_shape + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index 02c305387d45..cf845db2b43b 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -354,18 +354,17 @@ def create_mock(test_vec): max = np.iinfo(ifm_dtype).max min = np.iinfo(ifm_dtype).min values = np.random.randint(min, max, test_vec["shape"], ifm_dtype) - compressed_weights = vela_api.compress_weights( + vela_api.compress_weights( weights=values, weights_zp=test_vec["zero_point"], weights_layout=test_vec["layout"], ifm_bitdepth=ifm_bitdepth, block_depth=test_vec["block_depth"], dilation=test_vec["dilation"], - accel_type=test_vec["accel"], + accel_config=test_vec["accel"], is_depthwise=test_vec["is_depthwise"], ) return mock_npu_encode_weights - return None for tv in test_vecs: mock_obj = create_mock(tv) diff --git a/tests/python/contrib/test_hexagon/README.md b/tests/python/contrib/test_hexagon/README.md new file mode 100644 index 000000000000..a47c3438bf57 --- /dev/null +++ b/tests/python/contrib/test_hexagon/README.md @@ -0,0 +1,517 @@ + + + + + + + + + + + + + + + + + +Documents manual TE schedule to illustrate Hexagon operator slicing. + +# High Level Notes + +* Using float32 (for now) so that tests will pass on CPU +* Using global storage scope (for now) which means "cache" reads and writes from global, to global +* TIR is pending changes from the work-in-progress layout RFC + (https://github.com/apache/tvm-rfcs/pull/39) +* TIR has been hand-edited for context and clarity + * Added C-style comments + * Changed variable names + * Added spacing and line breaks +* Naming conventions + * Using input (instead of activation) + * Using filter (instead of weight, kernel) + * Using `k` to denote channel-out and `c` or `rc` (reduction channel) to denote channel-in + * Using `rh` and `rw` (reduction height / width) to denote filter height and width + +# Calling Convention + +TODO: Map this packed string to parameters +conv2d_packed_filter-1-1-0-float32-1-1-1-64-64-128-llvm + +# Baseline conv2d + +This is a baseline 1x1 conv2d schedule for Hexagon. + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-1-1-1-64-64-128-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Filter | 1x1 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 128 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | + +## Assumptions + +* Pattern matching for microkernels is not senstive to cache reads and writes between the outer height (ho) and outer width (wo) loops. + +## To Do + +* n/a + +## Annotated TIR + +``` +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + allocate(input.cache: Pointer(global float32), float32, [32768]), storage_scope = global; + allocate(filter.cache: Pointer(global float32), float32, [2048]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [16384]), storage_scope = global; + + for (ko.outer: int32, 0, 4) { + for (ho.outer: int32, 0, 8) { + + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + + // filter cache read + for (co: int32, 0, 2) { + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[((((co*1024) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[(((((ko.outer*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] + } + } + } + } + + // compute + for (wo.c: int32, 0, 8) { + + // init output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + + // convolution + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * + (float32*)filter.cache[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } // end wo.c + + // cache write + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((ho.outer*65536) + (wo*8192)) + (ko.outer*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((wo*2048) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } // end ho.outer + } // end ko.outer +} +``` + +# Split on Channel Out and Height - "Full Output Slice" + +Adds new parameters `k_split` and `h_split` which creates a loop split on the outer channel out `ko` and height `ho` loops creating `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer` which means that cache allocation grow in proportion to `k_split` and `h_split` factors. + +The key changes in TIR versus the above are... + +1) Increased cache allocations: + +``` + // input cache grows by factor of h_split = 2 + allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // filter cache grows by factor of k_split = 2 + allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; + + // output cache grows by factor of h_split * k_split = 4 + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; +``` + +2) Outer loop splits using k_split and h_split factors + +``` + // ko.outer = outer loop split on ko using k_split factor + for (ko.outer: int32, 0, 2) { + // ho.outer = outer loop split on ho using h_split factor + for (ho.outer: int32, 0, 4) { +``` + +3) Inner loop splits in both cache read / write and compute schedules. This is taken from the compute schedule e.g. +``` + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { +``` + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-2-1-64-64-128-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Filter | 1x1 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 128 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | +| k_split | 2 | +| h_split | 2 | + +## Assumptions + +* n/a - With the loop splits on `ko` and `ho` the compute schedule is now over `ko.inner` `ho.inner` `wo` etc. This should fit the pattern matching for microkernels. + +## To Do + +* n/a + +## Annotated TIR + +``` +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + + // input cache grows by factor of h_split = 2 + allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // filter cache grows by factor of k_split = 2 + allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; + + // output cache grows by factor of h_split * k_split = 4 + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // ko.outer = outer loop split on ko using k_split factor + for (ko.outer: int32, 0, 2) { + // ho.outer = outer loop split on ho using h_split factor + for (ho.outer: int32, 0, 4) { + + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + } // end ho.inner + + // filter cache read + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 2) { + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[(((((ko.inner*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[((((((ko.outer*4096) + (ko.inner*2048)) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] + } + } + } + } + } // end ko.inner + + // compute + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + + // init output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + + // convolution + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * + (float32*)filter.cache[(((((ko.c.inner*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } // end wo.c + } // end ho.c.inner + } // end ko.c.inner + + // cache write + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } // end ho.inner + } // end ko.inner + } // end ho.outer + } // end ko.outer +} +``` + +# 3x3 conv2d (no padding) + +Change from a 1x1 filter to a 3x3 filter. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 filter will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. + +The key changes in TIR versus the above are... + +1) Increased input cache size to hold the vertically adjacent slice + +``` + // input cache grows to hold vertically adjacent slice + allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; +``` + +2) Loop over `ho.inner` upper bound increased from `h_split` = 2 to `h_split + 1` = 3 + +``` + for (ho.outer: int32, 0, 4) { + for (ho.inner: int32, 0, 3) { + if (((ho.outer*2) + ho.inner) < 8) { +``` + +The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. + + +3) Increased filter cache size to hold 3x3 filter + +``` + // filter cache grows to hold larger 3x3 filter + allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; +``` + +4) Loops over `rh` and `rw` the kernel spatial dimensions: +``` + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { +``` + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-3-1-0-float32-2-2-1-64-64-128-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Filter | 3x3 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 128 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | +| h_split | 2 | + +## Assumptions + +* n/a + +## To Do + +There may be some opportunity to optimize cache reuse in this case. Consider the loops over `ho.outer` and `ho.inner` and the index calculation `ho.outer * 64k + ho.inner * 32k` into the input pointer: + +| ho.outer | ho.inner | ho.outer * 64k + ho.inner * 32k | +| -------- | -------- | ------------------------------------- | +| 0 | 0 | 0 | +| 0 | 1 | 32k | +| 0 | 2 | 64k (vertical adjacent slice loop 0) | +| 1 | 0 | 64k | +| 1 | 1 | 96k | +| 1 | 2 | 128k (vertical adjacent slice loop 1) | +| 2 | 0 | 128k | +| 2 | 1 | 160k | +| 2 | 2 | 192k (vertical adjacent slice loop 2) | +| 3 | 0 | 192k | +| 3 | 1 | 224k | +| 3 | 2 | (No vertical adjacent slice loop 3) | + +Noe that the vertically adjacent slice in loop N (i.e. the loop where `ho.outer` = N) is reused in loop N + 1. + +## Annotated TIR + +``` +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 3, 3, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + // input cache grows to hold vertically adjacent slice + allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; + // filter cache grows to hold larger 3x3 filter + allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + for (ko.outer: int32, 0, 2) { + for (ho.outer: int32, 0, 4) { + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (ho.inner: int32, 0, 3) { + if (((ho.outer*2) + ho.inner) < 8) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + } + } + // filter cache read + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 2) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[(((((((ko.inner*18432) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[((((((((ko.outer*36864) + (ko.inner*18432)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] + } + } + } + } // end rw + } // end rh + } + } + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * + (float32*)filter.cache[(((((((ko.c.inner*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } // end rw + } // end rh + } + } + } + } // end wo.c + } // end ho.c.inner + } // end ko.c.inner + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } // end ho.inner + } // end ko.inner + } // end ho.outer + } // end ko.outer +}``` \ No newline at end of file diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py index e0b7fb20ab8e..07696b51a327 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py @@ -162,6 +162,8 @@ def conv2d_packed_filter( stride, padding, dtype, + k_split_factor, + h_split_factor, storage_scope="global", ): """ @@ -260,15 +262,50 @@ def compute(n, ho, wo, ko, hi, wi, ki): s[X_pad].compute_inline() s[X_packed].compute_inline() - # Perform scheduling - n, hid, wid, cid, hoff, woff, coff = s[Y].op.axis - slice = s[Y].fuse(wid, cid) + # cache read for the input / activation (X) Xl = s.cache_read(X_packed, storage_scope, [Y]) + Fl = s.cache_read(filt_packed, storage_scope, [Y]) + + # cache write for the output (Y) Yl = s.cache_write(Y, storage_scope) - s[Yl].compute_at(s[Y], hid) - n, hid, slice, hoff, woff, coff = s[Yl].op.axis - s[Xl].compute_at(s[Yl], slice) + ######################## + # cache write schedule # + ######################## + + # loop schedule corresponding with nhwc8h8w32c layout + # using k to represent output channel + n, ho, wo, ko, hi, wi, ki = s[Y].op.axis + + # loop split h and compute cache write at outer loop split + # to increase cache usage by factor of h_split_factor + koo, koi = s[Y].split(ko, factor=k_split_factor) + hoo, hoi = s[Y].split(ho, factor=h_split_factor) + s[Y].reorder(n, koo, hoo, koi, hoi, wo, hi, wi, ki) + s[Yl].compute_at(s[Y], hoo) + + #################### + # compute schedule # + #################### + + # loop schedule corresponding with nhwc8h8w32c layout + # using k to represent output channel + n, ho, wo, ko, hi, wi, ki = s[Yl].op.axis + + # reduction axes + # using rc to represent (reduction) input channel + rh, rw, rc = s[Yl].op.reduce_axis + + # split input channel by the block size + rco, rci = s[Yl].split(rc, factor=block_C) + + # loop split h and compute cache write at outer loop split + # to increase cache usage by factor of h_split_factor + koo, koi = s[Yl].split(ko, factor=k_split_factor) + hoo, hoi = s[Yl].split(ho, factor=h_split_factor) + s[Yl].reorder(n, koo, hoo, koi, hoi, wo, rco, hi, wi, ki, rci) + s[Xl].compute_at(s[Yl], hoo) + s[Fl].compute_at(s[Yl], hoo) binds = {} if storage_scope and storage_scope != "global": @@ -287,6 +324,8 @@ def conv2d_packed_filter_nhwhwc( stride, padding, dtype, + k_split_factor, + h_split_factor, storage_scope="global", ): """ @@ -299,7 +338,7 @@ def conv2d_packed_filter_nhwhwc( assert kernel_size == tuple(shape_oihw8i32o4i[2:4]) block_shape = get_block_shape() - block_H, block_W, _ = block_shape + block_H, block_W, block_C = block_shape shape = get_packed_activation_layout(shape_nhwc, block_shape, packed_C=False) logical_output_shape = get_conv2d_nhwc_shape( shape_nhwc, @@ -372,18 +411,66 @@ def compute(n, ho, wo, hi, wi, k): s[X_pad].compute_inline() s[X_packed].compute_inline() + # cache read for the input / activation (X) + Xl = s.cache_read(X_packed, storage_scope, [Y]) + Fl = s.cache_read(filt_packed, storage_scope, [Y]) + + # cache write for the output (Y) + Yl = s.cache_write(Y, storage_scope) + + ######################## + # cache write schedule # + ######################## + + # loop schedule corresponding with nhw8h8wc layout + # using k to represent output channel n, ho, wo, hi, wi, k = s[Y].op.axis - rh, rw, rc = s[Y].op.reduce_axis - rco, rci = s[Y].split(rc, factor=32) - s[Y].reorder(n, rco, wo, ho, k, hi, wi) - Xl = s.cache_read(X_packed, storage_scope, [Y]) - s[Xl].compute_at(s[Y], rco) + # split output channel by the block size + ko, ki = s[Y].split(k, factor=block_C) - ko, ki = s[Y].split(k, factor=32) - s[Y].reorder(n, rco, wo, ho, ko, hi, wi, ki) - Fl = s.cache_read(filt_packed, storage_scope, [Y]) - s[Fl].compute_at(s[Y], ko) + # loop split h and compute cache write at outer loop split + # to increase cache usage by factor of h_split_factor + koo, koi = s[Y].split(ko, factor=k_split_factor) + hoo, hoi = s[Y].split(ho, factor=h_split_factor) + s[Y].reorder(n, koo, hoo, koi, hoi, wo, hi, wi, ki) + s[Yl].compute_at(s[Y], hoo) + + #################### + # compute schedule # + #################### + + # loop schedule corresponding with nhw8h8wc layout + # using k to represent output channel + n, ho, wo, hi, wi, k = s[Yl].op.axis + + # reduction axes + # using rc to represent (reduction) input channel + rh, rw, rc = s[Yl].op.reduce_axis + + # split output & input channel by the block size + ko, ki = s[Yl].split(k, factor=block_C) + rco, rci = s[Yl].split(rc, factor=block_C) + + # loop split h and compute cache write at outer loop split + # to increase cache usage by factor of h_split_factor + koo, koi = s[Yl].split(ko, factor=k_split_factor) + hoo, hoi = s[Yl].split(ho, factor=h_split_factor) + s[Yl].reorder(n, koo, hoo, koi, hoi, wo, rco, hi, wi, ki, rci) + s[Xl].compute_at(s[Yl], hoo) + s[Fl].compute_at(s[Yl], hoo) + + ####################### + # cache read schedule # + ####################### + + # loop schedule corresponding with nhw8h8wc layout + # using k to represent output channel + n, ho, wo, hi, wi, c = s[Xl].op.axis + + # split intput channel by the block size + co, ci = s[Xl].split(c, factor=block_C) + s[Xl].reorder(n, ho, wo, co, hi, wi, ci) binds = {} if storage_scope and storage_scope != "global": @@ -397,13 +484,15 @@ def compute(n, ho, wo, hi, wi, k): class BaseConv2d: batch = tvm.testing.parameter(1) - in_size = tvm.testing.parameter(8, 56) - in_channel = tvm.testing.parameter(64) - out_channel = tvm.testing.parameter(64) - kernel = tvm.testing.parameter(3) + in_size = tvm.testing.parameter(8, 56, 64) + in_channel = tvm.testing.parameter(64, 128) + out_channel = tvm.testing.parameter(64, 128) + kernel = tvm.testing.parameter(1, 3) stride = tvm.testing.parameter(1) - pad = tvm.testing.parameter(1) + pad = tvm.testing.parameter(0, 1) dtype = tvm.testing.parameter("float32") + k_split_factor = tvm.testing.parameter(1, 2) + h_split_factor = tvm.testing.parameter(1, 2) class TestConv2dLogical(BaseConv2d): @@ -427,13 +516,37 @@ def test_conv2d(self, shape_nhwc, shape_oihw, kernel, stride, pad, dtype, target padding=(pad, pad, pad, pad), dtype=dtype, ) - return output, ref_output + + # nhwc8h8w32c -> nhwc + output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( + output.shape[0], + output.shape[1] * output.shape[4], + output.shape[2] * output.shape[5], + output.shape[3] * output.shape[6], + ) + + # slice output to match ref_output shape + # e.g. 8x8 spatial 3x3 filter = 6x6 ref output + # but still 8x8 output given the blocked layout + output = output[ + 0 : ref_output.shape[0] : 1, + 0 : ref_output.shape[1] : 1, + 0 : ref_output.shape[2] : 1, + 0 : ref_output.shape[3] : 1, + ] + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + tvm.testing.assert_allclose(output, ref_output, **tol) class TestConv2dPackedFilter(BaseConv2d): conv2d_impl = tvm.testing.parameter(conv2d_packed_filter, conv2d_packed_filter_nhwhwc) @tvm.testing.parametrize_targets("llvm") + @pytest.mark.skip("Skip due to being flaky on i386.") def test_conv2d( self, conv2d_impl, @@ -445,6 +558,8 @@ def test_conv2d( pad, dtype, target, + k_split_factor, + h_split_factor, ): inputs = [ np.random.uniform(0, 255, size=shape_nhwc).astype(dtype), @@ -465,8 +580,45 @@ def test_conv2d( stride=(stride, stride), padding=(pad, pad, pad, pad), dtype=dtype, + k_split_factor=k_split_factor, + h_split_factor=h_split_factor, ) - return output, ref_output + + # nhwc8h8w32c + if len(output.shape) == 7: + # nhwc8h8w32c -> nhwc + output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( + output.shape[0], + output.shape[1] * output.shape[4], + output.shape[2] * output.shape[5], + output.shape[3] * output.shape[6], + ) + + # nhwhwc + else: + # nhwhwc -> nhwc + output = output.transpose(0, 1, 3, 2, 4, 5).reshape( + output.shape[0], + output.shape[1] * output.shape[3], + output.shape[2] * output.shape[4], + output.shape[5], + ) + + # slice output to match ref_output shape + # e.g. 8x8 spatial 3x3 filter = 6x6 ref output + # but still 8x8 output given the blocked layout + output = output[ + 0 : ref_output.shape[0] : 1, + 0 : ref_output.shape[1] : 1, + 0 : ref_output.shape[2] : 1, + 0 : ref_output.shape[3] : 1, + ] + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + tvm.testing.assert_allclose(output, ref_output, **tol) if __name__ == "__main__": diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index ec512d7d714f..df4234e7e605 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -355,6 +355,34 @@ def load_vm(): assert_result_dict_holds(result_dict) +def test_conv1d(run_module): + def get_graph( + x_shape=((1, 3, 224)), + k_shape=(10, 3, 3), + groups=1, + padding=(1, 1), + strides=(1), + dilation=(1), + channels=None, + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv1d( + x, + kernel, + kernel_size=k_shape[2:3], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + channels=channels, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph(channels=10), run_module=run_module) + + def test_conv2d(run_module): def get_graph( x_shape=(1, 32, 8, 8), diff --git a/tests/python/contrib/test_vitis_ai/infrastructure.py b/tests/python/contrib/test_vitis_ai/infrastructure.py index e87d4f874630..578ac37da25b 100644 --- a/tests/python/contrib/test_vitis_ai/infrastructure.py +++ b/tests/python/contrib/test_vitis_ai/infrastructure.py @@ -99,7 +99,7 @@ def build_module( ), "Got {} Vitis-AI partitions, expected {}".format( partition_count, vitis_ai_partitions ) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target, params=params) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 2e4687fb7985..9d44d8f22f41 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -397,7 +397,7 @@ def test_compile_tflite_module_with_external_codegen_cmsisnn( tvmc_package = tvmc.compiler.compile_model( tvmc_model, - target=f"cmsis-nn, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 --executor=aot", + target=f"cmsis-nn, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 -executor=aot", output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], @@ -455,7 +455,7 @@ def test_compile_tflite_module_with_external_codegen_ethosu( tvmc_package = tvmc.compiler.compile_model( tvmc_model, - target=f"ethos-u -accelerator_config={accel_type}, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 --executor=aot", + target=f"ethos-u -accelerator_config={accel_type}, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 -executor=aot", output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], @@ -471,7 +471,11 @@ def test_compile_tflite_module_with_external_codegen_ethosu( for name in mlf_package.getnames() if re.match(r"\./codegen/host/src/\D+\d+\.c", name) ] - assert len(c_source_files) == 17 + # The number of c_source_files depends on the number of fused subgraphs that + # get offloaded to the NPU, e.g. conv2d->depthwise_conv2d->conv2d gets offloaded + # as a single subgraph if both of these operators are supported by the NPU. + # Currently there are two source files for CPU execution and two offload graphs + assert len(c_source_files) == 4 @mock.patch("tvm.relay.build") diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index 569c42020817..4d2fb56c5d4e 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os -import tarfile import pytest +import tvm from tvm.ir.module import IRModule from tvm.driver import tvmc @@ -229,3 +228,128 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): model_format="pytorch", shape_dict={"input": [1, 3, 224, 224]}, ) + + +def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + before = tvmc_model.mod + + expected_layout = "NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NCHW" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found" + + +def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip("onnx") + + tvmc_model = tvmc.frontends.load_model(onnx_resnet50) + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" + + +def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50): + # some CI environments wont offer Paddle, so skip in case it is not present + pytest.importorskip("paddle") + + tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle") + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" + + +def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" + + +def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip("onnx") + + tvmc_model = tvmc.frontends.load_model(onnx_resnet50) + before = tvmc_model.mod + + expected_layout = "NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NCHW" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 0426f5678153..11306bd58848 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize( - "target,pass_configs", [["llvm", []], ["c --executor=aot", ["tir.disable_vectorize=1"]]] + "target,pass_configs", [["llvm", []], ["c -executor=aot", ["tir.disable_vectorize=1"]]] ) def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory, target, pass_configs): pytest.importorskip("tflite") @@ -114,7 +114,7 @@ def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile tflite_compiled_model_mlf = tflite_compile_model( tflite_mobilenet_v1_1_quant, - target="c --executor=aot", + target="c -executor=aot", output_format="mlf", pass_context_configs=["tir.disable_vectorize=1"], ) diff --git a/tests/python/driver/tvmc/test_pass_config.py b/tests/python/driver/tvmc/test_pass_config.py new file mode 100644 index 000000000000..d8ffd7d4d521 --- /dev/null +++ b/tests/python/driver/tvmc/test_pass_config.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm.contrib.target.vitis_ai import vitis_ai_available +from tvm.driver import tvmc + +from tvm.driver.tvmc.common import TVMCException + + +def test_config_invalid_format(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) + + +def test_config_missing_from_tvm(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) + + +def test_config_unsupported_tvmc_config(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) + + +def test_config_empty(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs([""]) + + +def test_config_valid_config_bool(): + configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) + + assert len(configs) == 1 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == True + + +@pytest.mark.skipif( + not vitis_ai_available(), + reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", +) +def test_config_valid_multiple_configs(): + configs = tvmc.common.parse_configs( + [ + "relay.backend.use_auto_scheduler=false", + "tir.detect_global_barrier=10", + "relay.ext.vitis_ai.options.build_dir=mystring", + ] + ) + + assert len(configs) == 3 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == False + assert "tir.detect_global_barrier" in configs.keys() + assert configs["tir.detect_global_barrier"] == 10 + assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() + assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring" diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_pass_list.py similarity index 97% rename from tests/python/driver/tvmc/test_common.py rename to tests/python/driver/tvmc/test_pass_list.py index 5cac6a1378a5..de50b04f415a 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_pass_list.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import argparse import pytest from tvm.driver import tvmc -def test_common_parse_pass_list_str(): +def test_parse_pass_list_str(): assert [""] == tvmc.common.parse_pass_list_str("") assert ["FoldScaleAxis", "FuseOps"] == tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps") diff --git a/tests/python/driver/tvmc/test_shape_parser.py b/tests/python/driver/tvmc/test_shape_parser.py new file mode 100644 index 000000000000..f49d89ac7c0f --- /dev/null +++ b/tests/python/driver/tvmc/test_shape_parser.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse + +import pytest + +from tvm.driver import tvmc + + +def test_shape_parser(): + # Check that a valid input is parsed correctly + shape_string = "input:[10,10,10]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10]} + + +def test_alternate_syntax(): + shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} + + +@pytest.mark.parametrize( + "shape_string", + [ + "input:[10,10,10] input2:[20,20,20,20]", + "input: [10, 10, 10] input2: [20, 20, 20, 20]", + "input:[10,10,10],input2:[20,20,20,20]", + ], +) +def test_alternate_syntaxes(shape_string): + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} + + +def test_negative_dimensions(): + # Check that negative dimensions parse to Any correctly. + shape_string = "input:[-1,3,224,224]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + # Convert to strings to allow comparison with Any. + assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" + + +def test_multiple_valid_gpu_inputs(): + # Check that multiple valid gpu inputs are parsed correctly. + shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" + assert str(shape_dict) == expected + + +def test_invalid_pattern(): + shape_string = "input:[a,10]" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +def test_invalid_separators(): + shape_string = "input:5,10 input2:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +def test_invalid_colon(): + shape_string = "gpu_0/data_0:5,10 :test:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +@pytest.mark.parametrize( + "shape_string", + [ + "gpu_0/data_0:5,10 /:10,10", + "gpu_0/data_0:5,10 data/:10,10", + "gpu_0/data_0:5,10 /data:10,10", + "gpu_0/invalid/data_0:5,10 data_1:10,10", + ], +) +def test_invalid_slashes(shape_string): + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +def test_dot(): + # Check dot in input name + shape_string = "input.1:[10,10,10]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input.1": [10, 10, 10]} diff --git a/tests/python/driver/tvmc/test_target.py b/tests/python/driver/tvmc/test_target.py new file mode 100644 index 000000000000..afb099f3add6 --- /dev/null +++ b/tests/python/driver/tvmc/test_target.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm.driver import tvmc + +from tvm.driver.tvmc.common import TVMCException + + +def test_target_from_cli__error_duplicate(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("llvm, llvm") + + +def test_target_invalid_more_than_two_tvm_targets(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("cuda, opencl, llvm") + + +def test_target_from_cli__error_target_not_found(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("invalidtarget") + + +def test_target_from_cli__error_no_tvm_target(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("ethos-n77") + + +def test_target_two_tvm_targets(): + tvm_target, extra_targets = tvmc.common.target_from_cli( + "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu" + ) + + assert "opencl" in str(tvm_target) + assert "llvm" in str(tvm_target.host) + + # No extra targets + assert 0 == len(extra_targets) + + +def test_tokenize_target_with_opts(): + tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") + expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", "-opt2=value2"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_plus_sign(): + tokens = tvmc.common.tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v") + expected_tokens = ["foo", "-opt1=+value1", "--flag", ",", "bar", "-opt2=test,+v"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas(): + tokens = tvmc.common.tokenize_target("foo -opt1=v,a,l,u,e,1 --flag") + expected_tokens = ["foo", "-opt1=v,a,l,u,e,1", "--flag"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas_and_single_quotes(): + tokens = tvmc.common.tokenize_target("foo -opt1='v, a, l, u, e', bar") + expected_tokens = ["foo", "-opt1='v, a, l, u, e'", ",", "bar"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas_and_double_quotes(): + tokens = tvmc.common.tokenize_target('foo -opt1="v, a, l, u, e", bar') + expected_tokens = ["foo", '-opt1="v, a, l, u, e"', ",", "bar"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_dashes(): + tokens = tvmc.common.tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz") + expected_tokens = ["foo-bar1", "-opt-1=t-e-s-t", ",", "baz"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_parse_single_target_with_opts(): + targets = tvmc.common.parse_target("llvm -device=arm_cpu --system-lib") + + assert len(targets) == 1 + assert "device" in targets[0]["opts"] + assert "system-lib" in targets[0]["opts"] + + +def test_parse_multiple_target(): + targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu --system-lib") + + assert len(targets) == 2 + assert "compute-library" == targets[0]["name"] + assert "llvm" == targets[1]["name"] + + +def test_parse_multiple_target_with_opts(): + targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") + + assert len(targets) == 2 + assert "ethos-n77" == targets[0]["name"] + assert "myopt" in targets[0]["opts"] + assert "value" == targets[0]["opts"]["myopt"] + assert "llvm" == targets[1]["name"] + + +def test_parse_quotes_and_separators_on_options(): + targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") + targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") + targets_double_quote = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') + + assert len(targets_no_quote) == 1 + assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"] + + assert len(targets_single_quote) == 1 + assert "+v1.0x,+value" == targets_single_quote[0]["opts"]["option1"] + + assert len(targets_double_quote) == 1 + assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py new file mode 100644 index 000000000000..f6942299b751 --- /dev/null +++ b/tests/python/driver/tvmc/test_target_options.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse + +import pytest + +from tvm.driver import tvmc +from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc.target import generate_target_args, reconstruct_target_args + + +def test_target_to_argparse(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + ["--target=llvm", "--target-llvm-mattr=+fp,+mve", "--target-llvm-mcpu=cortex-m3"] + ) + assert parsed.target == "llvm" + assert parsed.target_llvm_mcpu == "cortex-m3" + assert parsed.target_llvm_mattr == "+fp,+mve" + + +def test_mapping_target_args(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args(["--target=llvm", "--target-llvm-mcpu=cortex-m3"]) + assert reconstruct_target_args(parsed) == {"llvm": {"mcpu": "cortex-m3"}} + + +def test_target_recombobulation_single(): + tvm_target, _ = tvmc.common.target_from_cli("llvm", {"llvm": {"mcpu": "cortex-m3"}}) + + assert str(tvm_target) == "llvm -keys=cpu -link-params=0 -mcpu=cortex-m3" + + +def test_target_recombobulation_many(): + tvm_target, _ = tvmc.common.target_from_cli( + "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu", + {"llvm": {"mcpu": "cortex-m3"}, "opencl": {"max_num_threads": 404}}, + ) + + assert "-max_num_threads=404" in str(tvm_target) + assert "-device=mali" in str(tvm_target) + assert "-mtriple=aarch64-linux-gnu" in str(tvm_target.host) + assert "-mcpu=cortex-m3" in str(tvm_target.host) + + +def test_error_if_target_missing(): + with pytest.raises( + TVMCException, + match="Passed --target-opencl-max_num_threads but did not specify opencl target", + ): + tvmc.common.target_from_cli( + "llvm", + {"opencl": {"max_num_threads": 404}}, + ) diff --git a/tests/python/driver/tvmc/test_tracker.py b/tests/python/driver/tvmc/test_tracker.py new file mode 100644 index 000000000000..2ca0fae8f45e --- /dev/null +++ b/tests/python/driver/tvmc/test_tracker.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tvm.driver import tvmc + + +def test_tracker_host_port_from_cli__hostname_port(): + input_str = "1.2.3.4:9090" + expected_host = "1.2.3.4" + expected_port = 9090 + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert expected_host == actual_host + assert expected_port == actual_port + + +def test_tracker_host_port_from_cli__hostname_port__empty(): + input_str = "" + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert actual_host is None + assert actual_port is None + + +def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): + input_str = "1.2.3.4" + expected_host = "1.2.3.4" + expected_port = 9090 + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert expected_host == actual_host + assert expected_port == actual_port diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py deleted file mode 100644 index bdfdb48ce6a0..000000000000 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ /dev/null @@ -1,413 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import argparse - -import pytest - -import tvm -from tvm.contrib.target.vitis_ai import vitis_ai_available -from tvm.driver import tvmc - -from tvm.driver.tvmc.common import TVMCException - - -def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): - # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip("tflite") - - tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) - before = tvmc_model.mod - - expected_layout = "NCHW" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NHWC" - and node.attrs.dst_layout == "NCHW" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found" - - -def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): - # some CI environments wont offer ONNX, so skip in case it is not present - pytest.importorskip("onnx") - - tvmc_model = tvmc.frontends.load_model(onnx_resnet50) - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" - - -def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50): - # some CI environments wont offer Paddle, so skip in case it is not present - pytest.importorskip("paddle") - - tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle") - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" - - -def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): - # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip("tflite") - - tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NHWC" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" - - -def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): - # some CI environments wont offer ONNX, so skip in case it is not present - pytest.importorskip("onnx") - - tvmc_model = tvmc.frontends.load_model(onnx_resnet50) - before = tvmc_model.mod - - expected_layout = "NCHW" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NCHW" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" - - -def test_tracker_host_port_from_cli__hostname_port(): - input_str = "1.2.3.4:9090" - expected_host = "1.2.3.4" - expected_port = 9090 - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert expected_host == actual_host - assert expected_port == actual_port - - -def test_tracker_host_port_from_cli__hostname_port__empty(): - input_str = "" - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert actual_host is None - assert actual_port is None - - -def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): - input_str = "1.2.3.4" - expected_host = "1.2.3.4" - expected_port = 9090 - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert expected_host == actual_host - assert expected_port == actual_port - - -def test_shape_parser(): - # Check that a valid input is parsed correctly - shape_string = "input:[10,10,10]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10]} - # Check that multiple valid input shapes are parse correctly - shape_string = "input:[10,10,10] input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that multiple valid input shapes with colons are parse correctly - shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that alternate syntax parses correctly - shape_string = "input: [10, 10, 10] input2: [20, 20, 20, 20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - shape_string = "input:[10,10,10],input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that negative dimensions parse to Any correctly. - shape_string = "input:[-1,3,224,224]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - # Convert to strings to allow comparison with Any. - assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" - # Check that multiple valid gpu inputs are parsed correctly. - shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" - assert str(shape_dict) == expected - - # Check that invalid pattern raises expected error. - shape_string = "input:[a,10]" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with invalid separators raises error. - shape_string = "input:5,10 input2:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 /:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid colon raises error. - shape_string = "gpu_0/data_0:5,10 :test:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 data/:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 /data:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with invalid slashes raises error. - shape_string = "gpu_0/invalid/data_0:5,10 data_1:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - - -def test_target_from_cli__error_duplicate(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("llvm, llvm") - - -def test_target_invalid_more_than_two_tvm_targets(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("cuda, opencl, llvm") - - -def test_target_from_cli__error_target_not_found(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("invalidtarget") - - -def test_target_from_cli__error_no_tvm_target(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("ethos-n77") - - -def test_target_two_tvm_targets(): - tvm_target, extra_targets = tvmc.common.target_from_cli( - "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu" - ) - - assert "opencl" in str(tvm_target) - assert "llvm" in str(tvm_target.host) - - # No extra targets - assert 0 == len(extra_targets) - - -def test_tokenize_target_with_opts(): - tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") - expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", "-opt2=value2"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_plus_sign(): - tokens = tvmc.common.tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v") - expected_tokens = ["foo", "-opt1=+value1", "--flag", ",", "bar", "-opt2=test,+v"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas(): - tokens = tvmc.common.tokenize_target("foo -opt1=v,a,l,u,e,1 --flag") - expected_tokens = ["foo", "-opt1=v,a,l,u,e,1", "--flag"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas_and_single_quotes(): - tokens = tvmc.common.tokenize_target("foo -opt1='v, a, l, u, e', bar") - expected_tokens = ["foo", "-opt1='v, a, l, u, e'", ",", "bar"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas_and_double_quotes(): - tokens = tvmc.common.tokenize_target('foo -opt1="v, a, l, u, e", bar') - expected_tokens = ["foo", '-opt1="v, a, l, u, e"', ",", "bar"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_dashes(): - tokens = tvmc.common.tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz") - expected_tokens = ["foo-bar1", "-opt-1=t-e-s-t", ",", "baz"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_parse_single_target_with_opts(): - targets = tvmc.common.parse_target("llvm -device=arm_cpu --system-lib") - - assert len(targets) == 1 - assert "device" in targets[0]["opts"] - assert "system-lib" in targets[0]["opts"] - - -def test_parse_multiple_target(): - targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu --system-lib") - - assert len(targets) == 2 - assert "compute-library" == targets[0]["name"] - assert "llvm" == targets[1]["name"] - - -def test_parse_multiple_target_with_opts(): - targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") - - assert len(targets) == 2 - assert "ethos-n77" == targets[0]["name"] - assert "myopt" in targets[0]["opts"] - assert "value" == targets[0]["opts"]["myopt"] - assert "llvm" == targets[1]["name"] - - -def test_parse_quotes_and_separators_on_options(): - targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") - targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") - targets_double_quote = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') - - assert len(targets_no_quote) == 1 - assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"] - - assert len(targets_single_quote) == 1 - assert "+v1.0x,+value" == targets_single_quote[0]["opts"]["option1"] - - assert len(targets_double_quote) == 1 - assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] - - -def test_config_invalid_format(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) - - -def test_config_missing_from_tvm(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) - - -def test_config_unsupported_tvmc_config(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) - - -def test_config_empty(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs([""]) - - -def test_config_valid_config_bool(): - configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) - - assert len(configs) == 1 - assert "relay.backend.use_auto_scheduler" in configs.keys() - assert configs["relay.backend.use_auto_scheduler"] == True - - -@pytest.mark.skipif( - not vitis_ai_available(), - reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", -) -def test_config_valid_multiple_configs(): - configs = tvmc.common.parse_configs( - [ - "relay.backend.use_auto_scheduler=false", - "tir.detect_global_barrier=10", - "relay.ext.vitis_ai.options.build_dir=mystring", - ] - ) - - assert len(configs) == 3 - assert "relay.backend.use_auto_scheduler" in configs.keys() - assert configs["relay.backend.use_auto_scheduler"] == False - assert "tir.detect_global_barrier" in configs.keys() - assert configs["tir.detect_global_barrier"] == 10 - assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() - assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring" diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py index f4c0cd102340..233977d66066 100644 --- a/tests/python/frontend/caffe/test_forward.py +++ b/tests/python/frontend/caffe/test_forward.py @@ -763,6 +763,94 @@ def test_forward_TanH(): _test_tanh(np.random.rand(10).astype(np.float32)) +####################################################################### +# Embed +# ----------- + + +def _test_embed(data, **kwargs): + """One iteration of Embed""" + _test_op(data, L.Embed, "Embed", **kwargs) + + +def test_forward_Embed(): + k = 20 + data = [i for i in range(k)] + np.random.shuffle(data) + # dimension is 1 + data = np.asarray(data) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 2 + data = np.reshape(data, [4, 5]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 3 + data = np.reshape(data, [2, 2, 5]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 4 + data = np.reshape(data, [2, 2, 5, 1]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + + ####################################################################### # Mobilenetv2 # ----------- diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 26bf58cbf384..4dfe89fe40e5 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -417,6 +417,17 @@ def test_forward_reuse_layers(self, keras): keras_model = keras.models.Model(data, z) verify_keras_frontend(keras_model) + def test_forward_lstm(self, keras): + data = keras.layers.Input(shape=(10, 32)) + rnn_funcs = [ + keras.layers.LSTM(16), + keras.layers.LSTM(16, return_sequences=True), + ] + for rnn_func in rnn_funcs: + x = rnn_func(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + def test_forward_rnn(self, keras): data = keras.layers.Input(shape=(1, 32)) rnn_funcs = [ @@ -613,6 +624,7 @@ def test_forward_nested_layers(self, keras): sut.test_forward_multi_inputs(keras=k) sut.test_forward_multi_outputs(keras=k) sut.test_forward_reuse_layers(keras=k) + sut.test_forward_lstm(keras=k) sut.test_forward_rnn(keras=k) sut.test_forward_vgg16(keras=k) sut.test_forward_vgg16(keras=k, layout="NHWC") diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 69bb44e360ff..dd1c77330986 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3970,6 +3970,7 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), ] input_names = ["X", "roi", "scales"] + if oshape != []: nodes.append( make_constant_node("sizes", onnx.TensorProto.INT64, (len(oshape),), oshape) @@ -4941,7 +4942,6 @@ def verify_eyelike(indata): "test_mvn", # This test fails llvm with a lowering error: "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", - "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", "test_range_int32_type_negative_delta_expanded", @@ -4955,15 +4955,7 @@ def verify_eyelike(indata): "test_reduce_sum_keepdims_random", "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", - "test_resize_downsample_sizes_cubic", - "test_resize_downsample_sizes_linear_pytorch_half_pixel", - "test_resize_downsample_sizes_nearest", "test_resize_tf_crop_and_resize", - "test_resize_upsample_sizes_cubic", - "test_resize_upsample_sizes_nearest", - "test_resize_upsample_sizes_nearest_ceil_half_pixel", - "test_resize_upsample_sizes_nearest_floor_align_corners", - "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", "test_rnn_seq_length", "test_round", "test_scan9_sum", diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 1d64f947e68a..b274d178c9c2 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -80,9 +80,8 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5): baseline_outputs = (baseline_outputs.numpy(),) mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) - parms_num = min(len(input_names), len(mod["main"].params)) compiled_names = [] - for arg in mod["main"].params[:parms_num]: + for arg in mod["main"].params: assert arg.name_hint in input_names or arg.name_hint in params if arg.name_hint in input_names: compiled_names.append(arg.name_hint) @@ -383,31 +382,37 @@ def cusum3(inputs): @tvm.testing.uses_gpu def test_forward_conv(): - conv2d_input_shape = [1, 3, 10, 10] - class Conv2D1(nn.Layer): - def __init__(self): + def __init__(self, stride=1, padding=0, dilation=1, groups=1, padding_mode="zeros"): super(Conv2D1, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, bias_attr=True) + self.conv = nn.Conv2D( + 3, + 6, + 3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + ) self.softmax = nn.Softmax() @paddle.jit.to_static def forward(self, inputs): return self.softmax(self.conv(inputs)) - class Conv2D2(nn.Layer): - def __init__(self): - super(Conv2D2, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False) - self.softmax = nn.Softmax() - - @paddle.jit.to_static - def forward(self, inputs): - return self.softmax(self.conv(inputs)) + input_shapes = [[1, 3, 10, 10], [1, 3, 12, 12]] - conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") - verify_model(Conv2D1(), input_data=conv2d_input_data) - verify_model(Conv2D2(), input_data=conv2d_input_data) + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Conv2D1(), input_data=input_data) + verify_model(Conv2D1(stride=2, padding="VALID", dilation=3), input_data=input_data) + verify_model(Conv2D1(stride=2, padding="SAME", dilation=3), input_data=input_data) + verify_model( + Conv2D1(stride=2, padding=3, dilation=3, padding_mode="replicate"), + input_data=input_data, + ) + verify_model(Conv2D1(stride=2, padding="SAME", dilation=2, groups=3), input_data=input_data) @tvm.testing.uses_gpu @@ -539,6 +544,26 @@ def full2(inputs): verify_model(full2, input_data=[input_data]) +@tvm.testing.uses_gpu +def test_forward_squeeze(): + class Squeeze(nn.Layer): + def __init__(self, axis=None): + super(Squeeze, self).__init__() + self.axis = axis + + @paddle.jit.to_static + def forward(self, inputs): + return paddle.squeeze(inputs, axis=self.axis) + + input_shapes = [[1, 1, 3, 1, 5], [5, 1, 6]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Squeeze(axis=None), input_data=input_data) + verify_model(Squeeze(axis=1), input_data=input_data) + input_data = paddle.rand([1], dtype="float32") + verify_model(Squeeze(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_ones_like(): @paddle.jit.to_static @@ -723,24 +748,55 @@ def forward(self, input1, input2): @tvm.testing.uses_gpu def test_forward_pool2d(): - @paddle.jit.to_static - def pool2d1(inputs): - return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) + class Pool2D1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) - @paddle.jit.to_static - def pool2d2(inputs): - return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) + class Pool2D2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) + + class Pool2D3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.avg_pool2d( + inputs, + kernel_size=3, + stride=1, + padding=[1, 1], + exclusive=False, + divisor_override=2.5, + ) + + input_shapes = [[1, 2, 8, 8], [1, 3, 10, 10]] + for input_shape in input_shapes: + input_data = paddle.uniform(shape=input_shape, dtype="float32", min=-1, max=1) + verify_model(Pool2D1(), input_data=input_data) + verify_model(Pool2D2(), input_data=input_data) + verify_model(Pool2D3(), input_data=input_data) - @paddle.jit.to_static - def pool2d3(inputs): - return nn.functional.max_pool2d( - inputs, kernel_size=2, stride=2, padding=0, return_mask=True - ) - input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) - verify_model(pool2d1, input_data=input_data) - verify_model(pool2d2, input_data=input_data) - # verify_model(pool2d3, input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_pad3d(): + class Pad3D(nn.Layer): + def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCDHW"): + super(Pad3D, self).__init__() + self.pad3d = paddle.nn.Pad3D(padding, mode=mode, value=value, data_format=data_format) + + @paddle.jit.to_static + def forward(self, inputs): + return self.pad3d(inputs) + + input_shapes = [[1, 2, 2, 5, 5], [1, 2, 2, 5, 9]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Pad3D(padding=2), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1]), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], value=0.3), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], mode="reflect"), input_data=input_data) + verify_model(Pad3D(padding=3, mode="replicate"), input_data=input_data) @tvm.testing.uses_gpu diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3a3889d5cfb7..0031f4143fab 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3962,5 +3962,35 @@ def test_fn(f, dim=None, keepdim=False): verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()]) +@tvm.testing.uses_gpu +def test_searchsorted(): + def test_fn(out_int32=False, right=False): + return lambda x, y: torch.searchsorted(x, y, out_int32=out_int32, right=right) + + sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + verify_model(test_fn(), [sorted_sequence, values]) + verify_model(test_fn(out_int32=True), [sorted_sequence[0], values[0]]) + verify_model(test_fn(right=True), [sorted_sequence, values]) + + sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + values = torch.tensor([[3, 6, 9], [4, 2, 7]]) + verify_model(test_fn(), [sorted_sequence_1d, values]) + + verify_model(test_fn(), [sorted_sequence_1d, torch.tensor(6)]) + + +@tvm.testing.uses_gpu +def test_bucketize(): + def test_fn(out_int32=False, right=False): + return lambda x, y: torch.bucketize(x, y, out_int32=out_int32, right=right) + + boundaries = torch.tensor([1, 3, 5, 7, 9]) + values = torch.tensor([3, 6, 9]) + + verify_model(test_fn(), [values, boundaries]) + verify_model(test_fn(out_int32=True, right=True), [values, boundaries]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 4a6f88417b9c..754976ca8c13 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -161,6 +161,7 @@ def run_tvm_graph( target="llvm", out_names=None, mode="graph_executor", + op_converter=relay.frontend.tflite.OperatorConverter, ): """Generic function to compile on relay and execute on tvm""" # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 @@ -185,7 +186,7 @@ def run_tvm_graph( dtype_dict[e] = input_data[i].dtype.name mod, params = relay.frontend.from_tflite( - tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=op_converter ) if mode in ["debug", "vm"]: @@ -3996,6 +3997,72 @@ def test_detection_postprocess(): ) +####################################################################### +# Custom Converter +# ---------------- + + +def test_custom_op_converter(): + """Test case for user-defined operator converter in TFLite frontend""" + + class DummyOperatorConverter(relay.frontend.tflite.OperatorConverter): + """Operator Converter for converting TFLite ops to relay ops""" + + def __init__(self, model, subgraph, exp_tab): + super(DummyOperatorConverter, self).__init__(model, subgraph, exp_tab) + self.allow_custom_ops = True + + convert_map_overwrite = {"SUB": self.convert_sub_dummy} + + self.convert_map.update(convert_map_overwrite) + + def convert_sub_dummy(self, op): + """Convert TFLite SUB""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + lhs_tensor = input_tensors[0] + rhs_tensor = input_tensors[1] + + lhs_expr = self.get_expr(lhs_tensor.tensor_idx) + rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + + temp_expr = relay.op.negative(rhs_expr) + out = relay.op.add(lhs_expr, temp_expr) + + return out + + with tf.Graph().as_default(): + # Generate TFLite model for single addition + data = [ + np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), + np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)), + ] + in_data = [ + array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in_0"), + array_ops.placeholder(shape=data[1].shape, dtype="float32", name="in_1"), + ] + out = math_ops.subtract(in_data[0], in_data[1]) + in_name = [x[1] for x in zip(in_data, ("in_0:0", "in_1:0"))] + input_tensors = [x for x in in_data] + output_tensors = [out] + in_node = [0] * len(in_name) + for i in range(len(in_name)): + in_node[i] = in_name[i].split(":")[0] if ":" in in_name[i] else in_name[i] + + with tf.Session() as sess: + converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors) + tflite_model_buf = converter.convert() + in_data = [x[1] for x in zip(in_data, data)] + tvm_output_orig = run_tvm_graph(tflite_model_buf, in_data, in_node) + tvm_output_dummy = run_tvm_graph( + tflite_model_buf, in_data, in_node, op_converter=DummyOperatorConverter + ) + tvm.testing.assert_allclose( + np.squeeze(tvm_output_orig[0]), np.squeeze(tvm_output_dummy[0]), rtol=1e-5, atol=1e-5 + ) + + ####################################################################### # Mobilenet # --------- @@ -4621,6 +4688,9 @@ def test_prevent_tensorflow_dynamic_range(): # Detection_PostProcess test_detection_postprocess() + # Overwrite Converter + test_custom_op_converter() + # End to End test_forward_mobilenet_v1() test_forward_mobilenet_v2() diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py index 690258c2fa3b..63733b05ab3f 100644 --- a/tests/python/integration/test_lower.py +++ b/tests/python/integration/test_lower.py @@ -33,9 +33,8 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: # body for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"): for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"): - with T.block([16, 8]) as [bx, by]: - T.bind(bx, blockIdx_x) - T.bind(by, blockIdx_y) + with T.block(): + bx, by = T.axis.remap("SS", [blockIdx_x, blockIdx_y]) shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared") shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared") wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a") @@ -44,9 +43,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: for ty in T.thread_binding(0, 2, "threadIdx.y"): for tz in T.thread_binding(0, 2, "threadIdx.z"): for i, j in T.grid(2, 4): - with T.block([64, 64]) as [vi, vj]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) T.reads([]) T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) C0 = T.match_buffer( @@ -74,23 +73,23 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: for tx in T.thread_binding(0, 32, "threadIdx.x"): for i0, j0 in T.grid(1, 4): for j1 in T.vectorized(0, 4): - with T.block([1024, 1024]) as [vi, vj]: - T.bind(vi, bx * 64 + ty * 32 + tx + i0) - T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + with T.block(): + vi = T.axis.S(1024, bx * 64 + ty * 32 + tx + i0) + vj = T.axis.S(1024, ko * 32 + tz * 16 + j0 * 4 + j1) shared_A[vi, vj + 8] = A[vi, vj] for i0, j0 in T.grid(2, 4): for j1 in T.vectorized(0, 4): - with T.block([1024, 1024]) as [vi, vj]: - T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0) - T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + with T.block(): + vi = T.axis.S(1024, by * 128 + ty * 64 + tx * 2 + i0) + vj = T.axis.S(1024, ko * 32 + tz * 16 + j0 * 4 + j1) shared_B[vi, vj + 8] = B[vi, vj] for ki in range(0, 2): for i in range(0, 2): - with T.block([64, 64]) as [vi, vk]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vk, ko * 2 + ki) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vk = T.axis.S(64, ko * 2 + ki) T.reads( shared_A[ vi * 16 : vi * 16 + 16, @@ -142,9 +141,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for j in range(0, 4): - with T.block([64, 64]) as [vj, vk]: - T.bind(vj, by * 8 + tz * 4 + j) - T.bind(vk, ko * 2 + ki) + with T.block(): + vj = T.axis.S(64, by * 8 + tz * 4 + j) + vk = T.axis.S(64, ko * 2 + ki) T.reads( shared_B[ vj * 16 : vj * 16 + 16, @@ -196,14 +195,10 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for i, j in T.grid(2, 4): - with T.block([64, 64, T.reduce_axis(0, 64)]) as [ - vi, - vj, - vk, - ]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) - T.bind(vk, ko * 2 + ki) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) + vk = T.axis.R(64, ko * 2 + ki) T.reads( [ wmma_A[ @@ -258,9 +253,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for i, j in T.grid(2, 4): - with T.block([64, 64]) as [vi, vj]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) s0 = T.var("int32") diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index ca097734a9eb..a40164ded941 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np import tvm from tvm import te, topi -import numpy as np +from tvm.driver.build_module import schedule_to_module import tvm.testing import tvm.topi.testing @@ -532,10 +533,7 @@ def test_reduce_storage_reuse(): target = tvm.target.Target("cuda") def run_passes(sch, args): - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) return tvm.transform.Sequential( [ diff --git a/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py b/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py index 8cecbf97c001..484ec23b369a 100644 --- a/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py +++ b/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py @@ -14,15 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm import os import sys -from tvm import relay -from tvm.relay import quantize as qtz import logging + +import pytest + +pytest.importorskip("onnx") + import onnx + +import tvm +from tvm import relay +from tvm.relay import quantize as qtz import tvm.testing -import mxnet as mx from test_quantization_accuracy import Config, get_val_data, eval_acc logging.basicConfig(level=logging.INFO) diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 746f595a4422..276cad375357 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -33,8 +33,10 @@ import tvm from tvm import relay +from tvm import te from tvm.contrib import utils, graph_executor -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler +from tvm.relay.backend.te_compiler import TECompiler from tvm.relay.backend.utils import mangle_module_name from tvm.micro import export_model_library_format @@ -721,7 +723,6 @@ def compile_and_run( def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" - compile_engine.get().clear() with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target, params=params) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 22583eda4a40..7669d02cd536 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -41,7 +41,7 @@ def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets() tvm.testing.assert_allclose(op_result.numpy(), ref_result, rtol=1e-5) else: tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() @tvm.testing.uses_gpu @@ -251,7 +251,8 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ verify_sparse_to_dense( [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1] ) # floats - verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + # default value not specified + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) @pytest.mark.parametrize( diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index decddc1ef0a4..f42f7ad7ca69 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -23,6 +23,7 @@ from tvm import relay, te from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type +from tvm.topi.testing import searchsorted_ref from utils import ref_funcs from utils.assert_diagnostic import DiagnosticTesting @@ -2064,5 +2065,57 @@ def verify_scatter_nd(data_np, indices_np, updates_np, ref_res): verify_scatter_nd(data, indices, updates, out) +@tvm.testing.uses_gpu +def test_gather(): + def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis): + x = relay.var("x", relay.TensorType(data_shape, "float32")) + y = relay.var("y", relay.TensorType(indices_shape, "int32")) + z = relay.gather(x, axis, y) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + data_np = np.random.uniform(size=data_shape_np).astype("float32") + indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32") + + ref_res = tvm.topi.testing.gather_python(data_np, axis, indices_np) + check_result([data_np, indices_np], mod, [ref_res]) + + verify_gather((relay.Any(),), (relay.Any(),), (10,), (10,), 0) + verify_gather((2, 2), (2, relay.Any()), (2, 2), (2, 3), 1) + verify_gather((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3), 1) + verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0) + + +@tvm.testing.uses_gpu +def test_searchsorted(): + def verify_searchsorted( + sorted_sequence_shape, values_shape, sorted_sequence_shape_np, values_shape_np + ): + x = relay.var("x", relay.TensorType(sorted_sequence_shape, "float32")) + y = relay.var("y", relay.TensorType(values_shape, "float32")) + z = relay.searchsorted(x, y) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + x_np = np.sort(np.random.uniform(size=sorted_sequence_shape_np).astype("float32"), axis=-1) + y_np = np.random.uniform(size=values_shape_np).astype("float32") + + ref_res = searchsorted_ref(x_np, y_np, False, "int32") + check_result([x_np, y_np], mod, [ref_res]) + + for shape_np, values_shape_np in zip([(8, 9, 10), (10,), (11,)], [(8, 9, 20), (5,), (8, 9, 7)]): + sorted_sequence_shape = (relay.Any(),) * len(shape_np) + values_shape = (relay.Any(),) * len(values_shape_np) + + verify_searchsorted( + sorted_sequence_shape, + values_shape, + shape_np, + values_shape_np, + ) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index b179096a0528..0ab0122fa798 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. """ test bind function.""" +import pytest import tvm from tvm import te from tvm import relay +from tvm import TVMError def test_bind_params(): @@ -34,5 +36,16 @@ def test_bind_params(): assert tvm.ir.structural_equal(zbinded, zexpected) +def test_bind_duplicated_params(): + a = relay.var("a", shape=(1,)) + aa = relay.var("a", shape=(1,)) + s = a + aa + func = relay.Function([a, aa], s) + + with pytest.raises(TVMError): + relay.build_module.bind_params_by_name(func, {"a": [1.0]}) + + if __name__ == "__main__": test_bind_params() + test_bind_duplicated_params() diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index ca792204c835..c6eb7531f635 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -26,7 +26,7 @@ from tvm import relay, runtime from tvm.contrib import utils from tvm.relay import transform -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.contrib.register import get_pattern_table @@ -47,7 +47,7 @@ def check_result( return # Run the reference result - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(ref_mod, target=target, params=params) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) @@ -61,7 +61,7 @@ def check_result( ref_result = out.numpy() def check_vm_result(): - compile_engine.get().clear() + te_compiler.get().clear() with relay.build_config(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -71,7 +71,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.numpy(), ref_result, rtol=tol, atol=tol) def check_graph_executor_result(): - compile_engine.get().clear() + te_compiler.get().clear() with relay.build_config(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index eaddd33678df..754c9d1c4a74 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1422,7 +1422,8 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ verify_sparse_to_dense( [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1] ) # floats - verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + # default value not specified + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # negative test cases # sparse indices should be ints @@ -1757,7 +1758,7 @@ def verify_func(target, dev, func, data, ref_res): tvm.testing.assert_allclose(op_result.numpy(), ref_result, rtol=1e-5) else: tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() def test_adv_index(target, dev, executor_kind): @@ -1970,7 +1971,8 @@ def calc_numpy_unique(data, is_sorted=False): uniq = uniq[order].astype(data.dtype) inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") counts = counts[order].astype("int32") - index = np.sort(index) # In unsorted case, need to sort the index of first occurence + # In unsorted case, need to sort the index of first occurence + index = np.sort(index) return [ uniq.astype(data.dtype), index.astype("int32"), diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index eb4eee379b08..c968c5a7f19f 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -773,7 +773,6 @@ def verify_roi_align( mode=mode, ) for target, dev in tvm.testing.enabled_targets(): - print("test on", target) op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( np_data, np_rois ) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index ea640c62dfeb..48c58dc2dc33 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -20,6 +20,7 @@ import numpy as np import tvm from tvm import relay +from tvm.topi.testing import searchsorted_ref import tvm.testing @@ -149,5 +150,28 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): verify_topk(k, axis, ret_type, False, "int64", "float16") +@tvm.testing.uses_gpu +def test_searchsorted(): + def verify_searchsorted(right, dtype): + shape = (8, 9, 10) + values_shape = shape[:-1] + (10,) + sorted_sequence = relay.var("sorted_sequence", relay.TensorType(shape, "float32")) + values = relay.var("sorted_sequence", relay.TensorType(values_shape, "float32")) + out = relay.searchsorted(sorted_sequence, values, right, dtype) + func = relay.Function([sorted_sequence, values], out) + sorted_sequence_np = np.sort(np.random.randn(*shape).astype("float32"), axis=-1) + values_np = np.random.randn(*values_shape).astype("float32") + np_indices = searchsorted_ref(sorted_sequence_np, values_np, right, dtype) + + for target, dev in tvm.testing.enabled_targets(): + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + sorted_sequence_np, values_np + ) + np.testing.assert_equal(op_res.numpy(), np_indices) + + verify_searchsorted(False, "int32") + verify_searchsorted(True, "int64") + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 3310b6b2ed69..ab36f79c6ea7 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -507,12 +507,12 @@ def expected(): bias = relay.layout_transform(bias, src_layout="NHWC", dst_layout="NCHW") bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c") add = relay.add(y, bias) - y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW") - mean = relay.mean(y, axis=1, exclude=True) - var = relay.variance(y, axis=1, exclude=True) + mean = relay.mean(add, axis=[1, 4], exclude=True) + var = relay.variance(add, axis=[1, 4], exclude=True) denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05)) gamma = relay.var("gamma", shape=(16,)) - denom = denom * gamma + denom_c16c = denom * relay.layout_transform(gamma, src_layout="C", dst_layout="C16c") + denom = relay.layout_transform(denom_c16c, src_layout="C16c", dst_layout="C") denom_expand1 = relay.expand_dims(denom, axis=1, num_newaxis=2) denom_expand2 = relay.expand_dims(denom_expand1, axis=0) denom_nchwc16 = relay.layout_transform( @@ -520,7 +520,10 @@ def expected(): ) out = add * denom_nchwc16 beta = relay.var("beta", shape=(16,)) - numerator = (-mean) * denom + beta + numerator_c16c = (-mean) * denom_c16c + relay.layout_transform( + beta, src_layout="C", dst_layout="C16c" + ) + numerator = relay.layout_transform(numerator_c16c, src_layout="C16c", dst_layout="C") numerator_expand1 = relay.expand_dims(numerator, axis=1, num_newaxis=2) numerator_expand2 = relay.expand_dims(numerator_expand1, axis=0) numerator_nchwc16 = relay.layout_transform( @@ -1096,8 +1099,8 @@ def expected_nchw(): y = relay.nn.conv2d( y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c" ) - ret = relay.layout_transform(y, "NCHW16c", "NCHW") - ret = relay.sum(ret, axis=[1], keepdims=True) + ret = relay.sum(y, axis=[1, 4], keepdims=True) + ret = relay.layout_transform(ret, "NCHW1c", "NCHW") y = relay.Function(analysis.free_vars(ret), ret) return y @@ -1126,9 +1129,8 @@ def expected_nhwc(): y = relay.nn.conv2d( y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c" ) - ret = relay.layout_transform(y, "NCHW16c", "NCHW") - ret = relay.sum(ret, axis=[1], keepdims=True) - ret = relay.layout_transform(ret, "NCHW", "NHWC") + ret = relay.sum(y, axis=[1, 4], keepdims=True) + ret = relay.layout_transform(ret, "NCHW1c", "NHWC") y = relay.Function(analysis.free_vars(ret), ret) return y @@ -1397,28 +1399,54 @@ def expected(): assert tvm.ir.structural_equal(a, b) +def test_conv2d_strided_slice_packed_to_unpacked(): + """We do not support propagating through packed to unpacked layout""" + x_shape = (1, 1, 1, 1, 4) + w_shape = (9, 1, 3, 3, 4, 4) + + def before(): + x = relay.var("x", shape=x_shape) + weight = relay.var("weight", shape=w_shape) + y = relay.nn.conv2d( + x, + weight, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW4c", + kernel_layout="OIHW4i4o", + ) + y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8]) + return relay.Function([x, weight], y) + + def expected(): + x = relay.var("x", shape=x_shape) + weight = relay.var("weight", shape=w_shape) + x_nchw = relay.layout_transform(x, src_layout="NCHW4c", dst_layout="NCHW") + weight_oihw = relay.layout_transform(weight, src_layout="OIHW4i4o", dst_layout="OIHW") + y = relay.nn.conv2d( + x_nchw, + weight_oihw, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NCHW4c") + y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8]) + return relay.Function([x, weight], y) + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs["data_layout"] = "NCHW" + new_attrs["kernel_layout"] = "OIHW" + return relay.nn.conv2d(data, weight, **new_attrs) + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + if __name__ == "__main__": - test_alter_op() - test_alter_return_none() - test_alter_layout() - test_alter_layout_dual_path() - test_alter_layout_lrn() - test_alter_layout_resnet() - test_alter_layout_broadcast_op() - test_alter_layout_broadcast_scalar_op() - test_alter_layout_scalar() - test_alter_layout_concatenate() - test_alter_layout_nchw_upsamping_op() - test_alter_layout_strided_slice() - test_alter_layout_depthwise_conv2d() - test_alter_layout_prelu() - test_alter_layout_pad() - test_alter_layout_pool() - test_alter_layout_sum() - test_alter_layout_nhwc_arm() - test_alter_layout_nhwc_int8_aarch64() - test_alter_op_with_global_var() - test_alter_op_dense() - test_alter_layout_strided_slice_axes_nhwc() - test_not_inplace_modify() - test_alter_op_dense_packed_data() + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 9b4d154360b2..2359dcdf93d9 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -248,6 +248,61 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_bias_pool_uses_specified_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.nn.bias_add(y, bias, axis=3) + # a useless tuple, which will be eliminated + y = relay.Tuple([y])[0] + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC") + y = relay.cast(y, "int32") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, "NHWC", "NCHW") + weight = relay.layout_transform(weight, "HWIO", "OIHW") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + + bias = relay.expand_dims(bias, axis=0, num_newaxis=3) + bias = relay.layout_transform(bias, "NHWC", "NCHW") + y = relay.add(y, bias) + # a useless tuple, which will be eliminated + y = relay.Tuple([y])[0] + y = relay.nn.relu(y) + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC", out_layout="NHWC") + y = relay.cast(y, "int32") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass( + a, + transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"], "nn.max_pool2d": ["NHWC"]}), + ) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + def test_conv_concat_convert_layout(): def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -412,6 +467,139 @@ def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_deformable_conv_bias_pool_uses_specified_convert_layout(): + def before(N, CI, H, W, CO, KH, KW, layout): + if layout == "NCHW": + data_shape = (N, CI, H, W) + weight_shape = (CO, CI, KH, KW) + kernel_layout = "OIHW" + else: + data_shape = (N, H, W, CI) + weight_shape = (KH, KW, CI, CO) + kernel_layout = "HWIO" + bias_shape = (CO,) + + data = relay.var("data", shape=data_shape, dtype="float32") + offset = relay.var("offset") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + bias = relay.var("bias", shape=bias_shape, dtype="float32") + + y = relay.nn.deformable_conv2d( + data, + offset, + weight, + kernel_size=(KH, KW), + channels=CO, + data_layout=layout, + kernel_layout=kernel_layout, + ) + y = relay.nn.bias_add(y, bias, axis=-1 if layout == "NHWC" else 1) + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout=layout) + y = relay.cast(y, "int32") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout, max_pool_layout=None): + layout_map = {"src": {}, "dst": {}} + if src_layout == "NCHW": + nchw = layout_map["src"] + nhwc = layout_map["dst"] + else: + nchw = layout_map["dst"] + nhwc = layout_map["src"] + + nchw["data_layout"] = "NCHW" + nchw["data_shape"] = (N, CI, H, W) + nchw["offset_shape"] = (N, KH * KW * 2, OH, OW) + nchw["weight_shape"] = (CO, CI, KH, KW) + nchw["kernel_layout"] = "OIHW" + + nhwc["data_layout"] = "NHWC" + nhwc["data_shape"] = (N, H, W, CI) + nhwc["offset_shape"] = (N, OH, OW, KH * KW * 2) + nhwc["weight_shape"] = (KH, KW, CI, CO) + nhwc["kernel_layout"] = "HWIO" + + bias_shape = (CO,) + + data = relay.var("data", shape=layout_map["src"]["data_shape"], dtype="float32") + offset = relay.var("offset", shape=layout_map["src"]["offset_shape"], dtype="float32") + weight = relay.var("weight", shape=layout_map["src"]["weight_shape"], dtype="float32") + bias = relay.var("bias", shape=bias_shape, dtype="float32") + + data = relay.layout_transform( + data, layout_map["src"]["data_layout"], layout_map["dst"]["data_layout"] + ) + offset = relay.layout_transform( + offset, layout_map["src"]["data_layout"], layout_map["dst"]["data_layout"] + ) + weight = relay.layout_transform( + weight, layout_map["src"]["kernel_layout"], layout_map["dst"]["kernel_layout"] + ) + y = relay.nn.deformable_conv2d( + data, + offset, + weight, + kernel_size=(KH, KW), + channels=CO, + data_layout=layout_map["dst"]["data_layout"], + kernel_layout=layout_map["dst"]["kernel_layout"], + ) + if layout_map["src"]["data_layout"] == "NHWC": + bias = relay.expand_dims(bias, axis=0, num_newaxis=3) + else: + bias = relay.expand_dims(bias, axis=1, num_newaxis=2) + bias = relay.expand_dims(bias, axis=0) + bias = relay.layout_transform( + bias, layout_map["src"]["data_layout"], layout_map["dst"]["data_layout"] + ) + y = relay.add(y, bias) + y = relay.nn.relu(y) + if max_pool_layout != layout_map["dst"]["data_layout"]: + y = relay.layout_transform(y, layout_map["dst"]["data_layout"], max_pool_layout) + y = relay.nn.max_pool2d( + y, pool_size=(2, 2), layout=max_pool_layout, out_layout=max_pool_layout + ) + y = relay.cast(y, "int32") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + # NHWC -> NCHW + a = before(1, 3, 224, 224, 32, 3, 3, "NHWC") + a = run_opt_pass( + a, + transform.ConvertLayout( + {"nn.deformable_conv2d": ["NCHW", "default"], "nn.max_pool2d": ["NHWC"]} + ), + ) + # - in the before() func, its last argument "NHWC" is also the layout of max_pool + b = run_opt_pass( + # max_pool has its own layout argument + expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NHWC", "NCHW", max_pool_layout="NHWC"), + transform.InferType(), + ) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + # NCHW -> NHWC + a = before(1, 3, 224, 224, 32, 3, 3, "NCHW") + a = run_opt_pass( + a, + transform.ConvertLayout( + {"nn.deformable_conv2d": ["NHWC", "default"], "nn.max_pool2d": ["NCHW"]} + ), + ) + # - in the before() func, its last argument "NCHW" is also the layout of max_pool + b = run_opt_pass( + # max_pool has its own layout argument + expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NCHW", "NHWC", max_pool_layout="NCHW"), + transform.InferType(), + ) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + def test_dual_path_convert_layout(): def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -702,6 +890,57 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_resnet_pool_uses_specified_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + weight2 = relay.var("weight2", shape=(1, 1, 64, 32)) + y = relay.nn.conv2d( + x, + weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.nn.relu(y) + y2 = relay.nn.conv2d( + x, weight2, channels=32, kernel_size=(1, 1), data_layout="NHWC", kernel_layout="HWIO" + ) + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.nn.global_max_pool2d(y, layout="NHWC") + return relay.Function(analysis.free_vars(y), y) + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + weight2 = relay.var("weight2", shape=(1, 1, 64, 32)) + weight1 = relay.layout_transform(weight1, "HWIO", "OIHW") + weight2 = relay.layout_transform(weight2, "HWIO", "OIHW") + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.relu(y) + y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1)) + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.nn.global_max_pool2d(y, layout="NHWC", out_layout="NHWC") + return relay.Function(analysis.free_vars(y), y) + + a = before() + a = run_opt_pass( + a, + transform.ConvertLayout( + {"nn.conv2d": ["NCHW", "default"], "nn.global_max_pool2d": ["NHWC"]} + ), + ) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + def test_scalar_convert_layout(): def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -2039,5 +2278,54 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_max_pool_uses_specified_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + x = relay.layout_transform(x, "NCHW", "NHWC") + weight = relay.layout_transform(weight, "OIHW", "OHWI") + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + ) + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC", out_layout="NHWC") + y = relay.layout_transform(y, "NHWC", "NCHW") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass( + a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"], "nn.max_pool2d": ["NHWC"]}) + ) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 93cd6f791765..5aba6229c5e2 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -22,6 +22,7 @@ import numpy as np import tvm +from tvm.relay.backend import te_compiler import tvm.relay.testing import tvm.relay.op as reg from tvm import relay @@ -29,7 +30,6 @@ from tvm.relay import transform from tvm.relay.testing import byoc from tvm.contrib import utils -from tvm.relay.backend import compile_engine from tvm.relay.expr_functor import ExprMutator from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.contrib.register import get_pattern_table @@ -143,7 +143,7 @@ def update_lib(lib): return lib def check_vm_result(): - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -157,7 +157,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol) def check_graph_executor_result(): - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) lib = update_lib(lib) @@ -508,7 +508,7 @@ def test_extern_dnnl_mobilenet(): ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( i_data, **params ) - compile_engine.get().clear() + te_compiler.get().clear() check_result(mod, {"data": i_data}, (1, 1000), ref_res.numpy(), tol=1e-5, params=params) @@ -950,7 +950,7 @@ def test_exec(mod, params, ref_mod, ref_params, out_shape): ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( i_data, **ref_params ) - compile_engine.get().clear() + te_compiler.get().clear() mod = get_partitoned_mod(mod, params, dnnl_patterns) diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index d9411c92c375..4a9b7eacdf65 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -16,6 +16,7 @@ # under the License. import pytest +import os import numpy as np import tvm import tvm.testing @@ -76,11 +77,11 @@ def get_manual_conf(mods, target): # The third output is the final output, the second output is for mod3, the first output # is for mod2 input. pipe_config1 = { - "mod_idx": 1, + "mod_idx": 0, "output": [ - {"output_idx": 0, "dependent": [{"mod_idx": 2, "input_name": "data_0"}]}, - {"output_idx": 1, "dependent": [{"mod_idx": 3, "input_name": "data_0"}]}, - {"output_idx": 2, "dependent": [{"mod_idx": 0, "input_name": "0"}]}, + {"output_idx": 0, "dependencies": [{"mod_idx": 1, "input_name": "data_0"}]}, + {"output_idx": 1, "dependencies": [{"mod_idx": 2, "input_name": "data_0"}]}, + {"output_idx": 2, "dependencies": [{"global_output_index": 0}]}, ], } mod_config[mods[0]] = { @@ -94,9 +95,9 @@ def get_manual_conf(mods, target): } pipe_config2 = { - "mod_idx": 2, + "mod_idx": 1, "output": [ - {"output_idx": 0, "dependent": [{"mod_idx": 3, "input_name": "data_1"}]}, + {"output_idx": 0, "dependencies": [{"mod_idx": 2, "input_name": "data_1"}]}, ], } mod_config[mods[1]] = { @@ -110,8 +111,8 @@ def get_manual_conf(mods, target): } pipe_config3 = { - "mod_idx": 3, - "output": [{"output_idx": 0, "dependent": [{"mod_idx": 0, "input_name": "1"}]}], + "mod_idx": 2, + "output": [{"output_idx": 0, "dependencies": [{"global_output_index": 1}]}], } mod_config[mods[2]] = { "pipeline": pipe_config3, @@ -128,7 +129,7 @@ def get_manual_conf(mods, target): def test_pipe_config_check(): # This function is used to trigger runtime error by applying wrong logic connection. - # Get the three pipeline modules here. + # Get three pipeline modules here. (mod1, mod2, mod3), dshape = get_mannual_mod() # The input or output name is illegal and expects a runtime error. @@ -179,10 +180,12 @@ def test_pipeline(): pipe_config = pipeline_executor.PipelineConfig() - # The global input named "data_0" will be connected to a input named "data_0" of mod1. + # The pipeline input named "data_0" will be connected to a input named "data_0" + # of mod1. pipe_config["input"]["data_0"].connect(pipe_config[mod1]["input"]["data_0"]) - # The global Input named "data_1" will be connected to a input named "data_1" of mod2. + # The pipeline Input named "data_1" will be connected to a input named "data_1" + # of mod2. pipe_config["input"]["data_1"].connect(pipe_config[mod2]["input"]["data_1"]) # The mod1 output[0] will be connected to a input named "data_0" of mod2. @@ -194,10 +197,10 @@ def test_pipeline(): # The mod2 output[2] will be connected to a input named "data_1" of mod3. pipe_config[mod2]["output"][0].connect(pipe_config[mod3]["input"]["data_1"]) - # The mod1 output[2] will be connected to global output[1]. + # The mod1 output[2] will be connected to pipeline output[0]. pipe_config[mod1]["output"][2].connect(pipe_config["output"]["0"]) - # The mod3 output[0] will be connected to global output[2]. + # The mod3 output[0] will be connected to pipeline output[1]. pipe_config[mod3]["output"][0].connect(pipe_config["output"]["1"]) # Print configueration (print(pipe_config)), the result looks like following. # @@ -231,9 +234,21 @@ def test_pipeline(): with tvm.transform.PassContext(opt_level=3): pipeline_mod_factory = pipeline_executor.build(pipe_config) + # Export the parameter configuration to a file. + directory_path = tvm.contrib.utils.tempdir().temp_dir + # If the directory does not exist, create it. + if not os.path.exists(directory_path): + os.makedirs(directory_path) + config_file_name = pipeline_mod_factory.export_library(directory_path) + + # Use the output of build to create and initialize PipelineModule. pipeline_module = pipeline_executor.PipelineModule(pipeline_mod_factory) assert pipeline_module + # Use the import function to create and initialize PipelineModule. + pipeline_module_test = pipeline_executor.PipelineModule.load_library(config_file_name) + assert pipeline_module_test.num_outputs == 2 + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_relay_te_compiler.py similarity index 93% rename from tests/python/relay/test_backend_compile_engine.py rename to tests/python/relay/test_relay_te_compiler.py index 092cae01f568..f8498ae83648 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_relay_te_compiler.py @@ -21,6 +21,7 @@ from tvm import relay from tvm import autotvm from tvm import topi +from tvm.relay.backend import te_compiler from tvm.relay.testing import run_infer_type from tvm.relay.testing.temp_op_attr import TempOpAttr import tvm.testing @@ -98,7 +99,7 @@ def _get_impls(dshape, wshape): weight = relay.var("wshape", shape=wshape) out = relay.nn.conv2d(data, weight, padding=(1, 1)) out = run_infer_type(out) - return relay.backend.compile_engine.get_valid_implementations( + return relay.backend.te_compiler.get_valid_implementations( relay.op.get("nn.conv2d"), out.attrs, [te.placeholder(dshape), te.placeholder(wshape)], @@ -121,7 +122,7 @@ def _select_impl(dshape, wshape, use_autotvm=False): weight = relay.var("wshape", shape=wshape) out = relay.nn.conv2d(data, weight, padding=(1, 1)) out = run_infer_type(out) - return relay.backend.compile_engine.select_implementation( + return relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), out.attrs, [te.placeholder(dshape), te.placeholder(wshape)], @@ -161,8 +162,8 @@ def _select_impl(dshape, wshape, use_autotvm=False): assert impl.name == "conv2d_1" -def test_compile_engine(): - engine = relay.backend.compile_engine.get() +def test_te_compiler(): + tec = relay.backend.te_compiler.get() def get_func(shape): x = relay.var("x", shape=shape) @@ -173,31 +174,30 @@ def get_func(shape): mod = relay.transform.InferType()(mod) return mod["main"] - z1 = engine.lower(get_func((10,)), "llvm") - z2 = engine.lower(get_func((10,)), "llvm") - z3 = engine.lower(get_func(()), "llvm") + z1 = tec.lower(get_func((10,)), "llvm") + z2 = tec.lower(get_func((10,)), "llvm") + z3 = tec.lower(get_func(()), "llvm") assert z1.same_as(z2) assert not z3.same_as(z1) if tvm.testing.device_enabled("cuda"): - z4 = engine.lower(get_func(()), "cuda") + z4 = tec.lower(get_func(()), "cuda") assert not z3.same_as(z4) # Test JIT target for target in ["llvm"]: dev = tvm.device(target) if tvm.testing.device_enabled(target): - f = engine.jit(get_func((10,)), target) + f = tec.jit(get_func((10,)), target) x = tvm.nd.array(np.ones(10).astype("float32"), device=dev) y = tvm.nd.empty((10,), device=dev) f(x, y) tvm.testing.assert_allclose(y.numpy(), x.numpy() * 3) - engine.dump() -# Note: Once compile engine is removed, we should keep this test so that +# Note: Once the te compiler is removed, we should keep this test so that # we make sure that opt_level=0 passes are being called correctly. def test_compile_placeholder_bypass(): - engine = relay.backend.compile_engine.get() + te_compiler = relay.backend.te_compiler.get() x = relay.var("x", shape=(2, 3)) y = relay.var("y", shape=(2, 3)) z = relay.var("z", shape=(2, 3)) @@ -264,7 +264,7 @@ def test_compile_nhwc_pack(): if __name__ == "__main__": test_get_valid_implementations() test_select_implementation() - test_compile_engine() + test_te_compiler() test_compile_placeholder_bypass() test_compile_injective_with_tuple() test_compile_tuple_dup() diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index 4d7a7fcdc15b..5856dc1e1c69 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -49,5 +49,28 @@ def test_tir_external_generation(check_result): check_result(func, inputs, (8,), x_data - y_data) +@pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) +def test_runtime_module_generation(check_result): + shape = (8,) + x_data = np.random.randint(255, size=shape).astype("float32") + y_data = np.random.randint(255, size=shape).astype("float32") + inputs = {"x": x_data, "y": y_data} + + x0 = relay.var("x0", shape=shape, dtype="float32") + y0 = relay.var("y0", shape=shape, dtype="float32") + z = x0 + y0 + func = relay.Function([x0, y0], z) + func = set_external_func_attr(func, "example_target_hook", "replace_add_with_subtract") + # Test hook to trigger TIRToRuntime code generation + func = func.with_attr("tir_to_runtime", True) + + x = relay.var("x", shape=(8,), dtype="float32") + y = relay.var("y", shape=(8,), dtype="float32") + call = relay.Call(func, [x, y]) + func = IRModule.from_expr(call) + + check_result(func, inputs, (8,), x_data * y_data) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 42fe1a3cef3a..8ec41523f9dc 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -766,6 +766,19 @@ def test_vm_reshape_tensor(target, dev): check_result(target, dev, [x_np, y_np], x_np.reshape([8, 2, 8]), mod) +def test_vm_reshape_and_copy(target, dev): + """Make sure the compiler notices the reshape result shape is a literal and can use + the immediate-mode alloc_tensor instruction instead of alloc_tensor_reg.""" + x_np = np.random.uniform(size=(1, 1)).astype("float32") + x = relay.var("x", shape=(1, 1), dtype="float32") + mod = tvm.IRModule.from_expr(relay.Function([x], relay.copy(relay.reshape(x, [0, 1])))) + with tvm.transform.PassContext(opt_level=3): + exec = relay.vm.compile(mod, "llvm") + assert "alloc_tensor" in exec.bytecode + assert not "alloc_tensor_reg" in exec.bytecode + check_result(target, dev, [x_np], x_np.reshape([1, 1]), mod) + + def test_vm_reshape_tuple(target, dev, x_shape=(1, 4, 2), y_shape=(1, 2, 10)): tup = relay.var( "tup", @@ -963,4 +976,4 @@ def test_benchmark_end_to_end_rpc(): if __name__ == "__main__": import sys - sys.exit(pytest.main(sys.argv)) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py new file mode 100644 index 000000000000..7b3976b7eb74 --- /dev/null +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing +from tvm.topi.testing import searchsorted_ref +from tvm import te, topi + +topi_funcs = {"generic": topi.searchsorted, "cuda": topi.cuda.searchsorted} + + +def get_implementations(): + topi_func_generic = topi_funcs["generic"] + topi_func_cuda = topi_funcs["cuda"] + + return { + "generic": ( + lambda x, y, side, out_dtype: topi_func_generic(x, y, side, out_dtype), + topi.generic.schedule_extern, + ), + "cuda": ( + lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype), + topi.cuda.schedule_extern, + ), + "vulkan": ( + lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype), + topi.cuda.schedule_extern, + ), + } + + +@tvm.testing.parametrize_targets +def test_searchsorted(dev, target): + def verify_with_input(sorted_sequence_np, values_np, right): + sorted_sequence = te.placeholder(sorted_sequence_np.shape, dtype="float32") + values = te.placeholder(values_np.shape, dtype="float32") + out_dtype = "int32" + implementations = get_implementations() + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + + with tvm.target.Target(target): + indices = fcompute(sorted_sequence, values, right, out_dtype) + s = fschedule([indices]) + + func = tvm.build(s, [sorted_sequence, values, indices], target=target) + dev = tvm.device(target, 0) + + a = tvm.nd.array(sorted_sequence_np, dev) + b = tvm.nd.array(values_np, dev) + c = tvm.nd.array(np.zeros(values_np.shape, dtype=indices.dtype), dev) + func(a, b, c) + ref = searchsorted_ref(sorted_sequence_np, values_np, right, out_dtype) + np.testing.assert_equal(c.numpy(), ref) + + def verify(sequence_len, num_search, outer_axes, right, sorted_sequence_1d=False): + if sorted_sequence_1d: + sorted_sequence_shape = (sequence_len,) + else: + sorted_sequence_shape = outer_axes + (sequence_len,) + values_shape = outer_axes + (num_search,) + + verify_with_input( + np.sort(np.random.randn(*sorted_sequence_shape).astype("float32"), axis=-1), + np.random.randn(*values_shape).astype("float32"), + right, + ) + + verify(1024, 1000, (10, 5, 3), False) + verify(999, 2000, (10, 5, 3), True) + verify(1000, 1000, (), False) + verify(2001, 100, (500,), True) + verify(2001, 100, (500,), False, sorted_sequence_1d=True) + + # Check edge cases + for right in [True, False]: + sorted_sequence = np.array([1, 2, 3, 4, 5], dtype="float32") + verify_with_input(sorted_sequence, np.array([6], dtype="float32"), right) + verify_with_input(sorted_sequence, np.array([0], dtype="float32"), right) diff --git a/tests/python/unittest/test_autotvm_graph_tuner_utils.py b/tests/python/unittest/test_autotvm_graph_tuner_utils.py index 3f6d3980ee28..583bd366847c 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_utils.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_utils.py @@ -20,6 +20,8 @@ # helps avoid topi arithmetic operator overloading issue: # https://github.com/apache/tvm/issues/3240 # TODO: restore the file name after this issue is resolved. +import pytest + import tvm from tvm import te @@ -34,6 +36,7 @@ bind_inputs, ) from tvm.autotvm.graph_tuner._base import OPT_OUT_OP +from tvm.autotvm.graph_tuner.utils.traverse_graph import _replace_device_with_tracing from tvm.relay.expr import Call, TupleGetItem, Tuple, Var @@ -57,7 +60,7 @@ def test_has_multiple_inputs(): target_ops = [relay.op.get("nn.conv2d")] node_list = [] node_dict = {} - expr2graph(net, target_ops, node_dict, node_list) + expr2graph(net, target_ops, node_dict, node_list, tvm.target.Target("llvm")) input_names = ["data"] verify_has_multiple_inputs(node_list, 2, input_names, False) verify_has_multiple_inputs(node_list, 4, input_names, False) @@ -79,7 +82,7 @@ def _count_node(node): relay.analysis.post_order_visit(mod["main"], _count_node) - expr2graph(mod["main"], target_ops, node_dict, node_list) + expr2graph(mod["main"], target_ops, node_dict, node_list, tvm.target.Target("llvm")) assert len(node_list) == len(op_name_list) for i, item in enumerate(zip(op_name_list, node_list)): op_name, node = item @@ -103,7 +106,7 @@ def test_get_direct_ancestor(): target_ops = [relay.op.get("nn.conv2d")] node_list = [] node_dict = {} - expr2graph(net, target_ops, node_dict, node_list) + expr2graph(net, target_ops, node_dict, node_list, tvm.target.Target("llvm")) visited_dict = {} input_names = ["data"] out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names) @@ -115,7 +118,7 @@ def test_get_direct_ancestor(): net = bind_inputs(net, {"data": (1, 16, 224, 224)}) node_list = [] node_dict = {} - expr2graph(net, target_ops, node_dict, node_list) + expr2graph(net, target_ops, node_dict, node_list, tvm.target.Target("llvm")) out = get_direct_ancestor(node_list, visited_dict, target_ops, 3, input_names) assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out) @@ -134,7 +137,7 @@ def test_get_in_nodes(): input_names = ["data"] node_list = [] node_dict = {} - expr2graph(net, target_ops, node_dict, node_list) + expr2graph(net, target_ops, node_dict, node_list, tvm.target.Target("llvm")) out = get_in_nodes(node_list, target_ops, input_names) expected_out = {3: [0], 4: [3, 0], 7: [4]} diff_set = set(out) ^ set(expected_out) @@ -155,6 +158,20 @@ def test_get_out_nodes(): ) +def test_target_device_replacement(): + assert _replace_device_with_tracing("cuda") == "cuda -device=tracing" + assert ( + _replace_device_with_tracing("cuda -device=some_device -libs=cudnn") + == "cuda -device=tracing -libs=cudnn" + ) + assert ( + _replace_device_with_tracing("llvm -device=arm_cpu -arg=xxx") + == "llvm -device=tracing -arg=xxx" + ) + assert _replace_device_with_tracing("llvm -device=arm_cpu") == "llvm -device=tracing" + assert _replace_device_with_tracing("llvm -device=abc, def") == "llvm -device=tracing" + + if __name__ == "__main__": test_has_multiple_inputs() test_expr2graph() diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 62e68ab01ce5..9450a937a155 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -35,6 +35,7 @@ from tvm.topi.utils import get_const_tuple from tvm.topi.testing import conv2d_nchw_python +pytest.importorskip("tvm.micro.testing") from tvm.micro.testing import check_tune_log BUILD = True diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 6502f0c67de6..fabf41705698 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -41,10 +41,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) - for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + for k in range(128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_meta_schedule_arg_info.py b/tests/python/unittest/test_meta_schedule_arg_info.py index 7bedea9082d1..62dcb52f7415 100644 --- a/tests/python/unittest/test_meta_schedule_arg_info.py +++ b/tests/python/unittest/test_meta_schedule_arg_info.py @@ -28,10 +28,12 @@ def Matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 256), "float32") B = T.match_buffer(b, (256, 512), "float32") C = T.match_buffer(c, (128, 512), "float32") - with T.block([128, 256, T.reduce_axis(0, 512)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 256, 512): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_builder.py b/tests/python/unittest/test_meta_schedule_builder.py index fa09a092c8c4..fb3fa135a9b8 100644 --- a/tests/python/unittest/test_meta_schedule_builder.py +++ b/tests/python/unittest/test_meta_schedule_builder.py @@ -47,10 +47,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @script.ir_module @@ -64,12 +66,16 @@ def matmul_relu( # pylint: disable=no-self-argument B = T.match_buffer(b, (1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") C = T.alloc_buffer((1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([1024, 1024], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @script.ir_module @@ -82,10 +88,12 @@ def batch_matmul( # pylint: disable=no-self-argument A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) - with T.block([16, 128, 128, T.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index cb39c91eaca4..121ec2fd480b 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -41,10 +41,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -56,12 +58,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (16, 16), "float32") D = T.match_buffer(d, (16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([16, 16], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(16, 16): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) # fmt: on diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py index 9fb1e5ef19c1..46be12569c78 100644 --- a/tests/python/unittest/test_meta_schedule_runner.py +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -68,10 +68,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -83,12 +85,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (16, 16), "float32") D = T.match_buffer(d, (16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([16, 16], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(16, 16): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @tvm.script.ir_module @@ -99,10 +105,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [16, 32, 32]) B = T.match_buffer(b, [16, 32, 32]) C = T.match_buffer(c, [16, 32, 32]) - with T.block([16, 32, 32, T.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 32, 32, 32): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] @tvm.script.ir_module @@ -113,8 +121,10 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [32], "float32") B = T.match_buffer(b, [32], "float32") C = T.match_buffer(c, [32], "float32") - with T.block([32], "add") as [vi]: - C[vi] = A[vi] + B[vi] + for i in range(32): + with T.block("add"): + vi = T.axis.S(32, i) + C[vi] = A[vi] + B[vi] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index e12871391558..9b3ddfd7c789 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -45,10 +45,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (32, 32), "float32") B = T.match_buffer(b, (32, 32), "float32") C = T.match_buffer(c, (32, 32), "float32") - with T.block([32, 32, T.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(32, 32, 32): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 39bb1acf065f..3f7749ca9e2c 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -40,10 +40,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index a30409696543..4854aeb5f5aa 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -48,10 +48,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -63,12 +65,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") C = T.alloc_buffer((1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([1024, 1024], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @tvm.script.ir_module @@ -79,10 +85,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) - with T.block([16, 128, 128, T.reduce_axis(0, 128)], "matmul") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("matmul"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks diff --git a/tests/python/unittest/test_meta_schedule_tune_context.py b/tests/python/unittest/test_meta_schedule_tune_context.py index 44bb949b925b..01a4379e5127 100644 --- a/tests/python/unittest/test_meta_schedule_tune_context.py +++ b/tests/python/unittest/test_meta_schedule_tune_context.py @@ -35,10 +35,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_micro_project_api.py b/tests/python/unittest/test_micro_project_api.py index b5e2a57c122c..e319318656ef 100644 --- a/tests/python/unittest/test_micro_project_api.py +++ b/tests/python/unittest/test_micro_project_api.py @@ -25,6 +25,8 @@ import pytest import tvm + +pytest.importorskip("tvm.micro") from tvm.micro import project_api diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index ca6cb0181489..b67142b42358 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -52,15 +52,20 @@ def read_csv(report): @pytest.mark.skipif(not profiler_vm.enabled(), reason="VM Profiler not enabled") @tvm.testing.parametrize_targets def test_vm(target, dev): - mod, params = mlp.get_workload(1) - - exe = relay.vm.compile(mod, target, params=params) + dtype = "float32" + x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype=dtype) + y = relay.var("y", shape=(relay.Any(), relay.Any()), dtype=dtype) + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], relay.add(x, y)) + exe = relay.vm.compile(mod, target) vm = profiler_vm.VirtualMachineProfiler(exe, dev) - data = np.random.rand(1, 1, 28, 28).astype("float32") - report = vm.profile(data, func_name="main") - assert "fused_nn_softmax" in str(report) + data = np.random.rand(28, 28).astype("float32") + report = vm.profile(data, data, func_name="main") + assert "fused_add" in str(report) assert "Total" in str(report) + assert "AllocTensorReg" in str(report) + assert "AllocStorage" in str(report) csv = read_csv(report) assert "Hash" in csv.keys() @@ -179,8 +184,15 @@ def test_report_serialization(): report = vm.profile(data, func_name="main") report2 = Report.from_json(report.json()) - # equality on reports compares pointers, so we compare the printed results instead. - assert str(report) == str(report2) + # Equality on reports compares pointers, so we compare the printed + # results instead. + + # Use .table() instead of str(), because str() includes aggregate + # and column summations whose values may be impacted by otherwise + # negligible conversion errors. (2 occurrences / 3000 trials) + assert report.table(aggregate=False, col_sums=False) == report2.table( + aggregate=False, col_sums=False + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 22aea8d1fcea..6e1fc815d66d 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -29,6 +29,7 @@ from tvm import rpc from tvm.contrib import utils, cc from tvm.rpc.tracker import Tracker +from tvm.rpc.proxy import Proxy if __name__ == "__main__": @@ -538,3 +539,46 @@ def test_rpc_tracker_request(): proc2.join() server.terminate() tracker.terminate() + + +@tvm.testing.requires_rpc +def test_rpc_tracker_via_proxy(): + """ + tracker + / \ + Host -- Proxy -- RPC server + """ + + device_key = "test_device" + + tracker_server = Tracker(port=9000, port_end=9100) + proxy_server = Proxy( + host=tracker_server.host, + port=8888, + port_end=8988, + tracker_addr=(tracker_server.host, tracker_server.port), + ) + + server1 = rpc.Server( + host=proxy_server.host, + port=proxy_server.port, + key=device_key, + tracker_addr=(tracker_server.host, tracker_server.port), + is_proxy=True, + ) + server2 = rpc.Server( + host=proxy_server.host, + port=proxy_server.port, + key=device_key, + tracker_addr=(tracker_server.host, tracker_server.port), + is_proxy=True, + ) + + client = rpc.connect_tracker(tracker_server.host, tracker_server.port) + remote1 = client.request(device_key, session_timeout=30) # pylint: disable=unused-variable + remote2 = client.request(device_key, session_timeout=30) # pylint: disable=unused-variable + + server2.terminate() + server1.terminate() + proxy_server.terminate() + tracker_server.terminate() diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 8c8d601672ac..5a1b33ae10b1 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -885,5 +885,21 @@ def check_llvm(use_file): check_llvm(use_file=False) +@tvm.testing.requires_llvm +def test_llvm_scalar_concat(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.decl_buffer((1,), "int32x2") + s = tvm.tir.Shuffle([x, y], [0, 1]) + f = tvm.tir.PrimFunc([x, y, z], z.vstore(0, s)) + + mod = tvm.ir.IRModule.from_expr(f.with_attr("global_symbol", "codegen_scalar_concat")) + + # This will crash in LLVM codegen if CodeGenLLVM::CreateVecConcat doesn't convert + # scalars to single-lane LLVM vectors. + with tvm.transform.PassContext(config={"tir.disable_assert": True}): + m = tvm.build(mod, [x, y, z], target="llvm") + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 987898001a1b..6b5c26d08b7b 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -54,10 +54,12 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i0, j0, k0 in T.grid(128, 128, 128): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] def test_matmul(): @@ -77,10 +79,14 @@ def tir_element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) B = T.alloc_buffer((128, 128)) - with T.block([128, 128]) as [i, j]: - B[i, j] = A[i, j] * 2.0 - with T.block([128, 128]) as [i, j]: - C[i, j] = B[i, j] + 1.0 + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + B[i, j] = A[i, j] * 2.0 + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + C[i, j] = B[i, j] + 1.0 def test_element_wise(): @@ -125,19 +131,21 @@ def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [16, 32, 14, 14]) Apad = T.alloc_buffer([16, 16, 16, 16]) - with T.block([16, 16, 16, 16], "Apad") as [nn, cc, yy, xx]: - Apad[nn, cc, yy, xx] = T.if_then_else( - yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14, - A[nn, cc, yy - 1, xx - 1], - 0.0, - dtype="float32", - ) - with T.block( - [16, 32, 14, 14, T.reduce_axis(0, 16), T.reduce_axis(0, 3), T.reduce_axis(0, 3)], "B" - ) as [nn, ff, yy, xx, rc, ry, rx]: - with T.init(): - B[nn, ff, yy, xx] = 0.0 - B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] + for n, c, y, x in T.grid(16, 16, 16, 16): + with T.block("Apad"): + nn, cc, yy, xx = T.axis.remap("SSSS", [n, c, y, x]) + Apad[nn, cc, yy, xx] = T.if_then_else( + yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14, + A[nn, cc, yy - 1, xx - 1], + 0.0, + dtype="float32", + ) + for n, f, y, x, kc, ky, kx in T.grid(16, 32, 14, 14, 16, 3, 3): + with T.block("B"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [n, f, y, x, kc, ky, kx]) + with T.init(): + B[nn, ff, yy, xx] = 0.0 + B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] def test_conv2d(): @@ -163,9 +171,11 @@ def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> B1 = T.match_buffer(b1, (m, n)) for i0, i1 in T.grid(m, n): - with T.block([m, n], "B.v0") as [i, j]: + with T.block("B.v0"): + i, j = T.axis.remap("SS", [i0, i1]) B0[i, j] = A0[i, j] + 2.0 - with T.block([m, n], "B.v1") as [i, j]: + with T.block("B.v1"): + i, j = T.axis.remap("SS", [i0, i1]) B1[i, j] = A1[i, j] * 3.0 @@ -193,7 +203,7 @@ def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) # body - with T.block([], "C"): + with T.block("C"): T.reads([A[0:128, 0:128], B[0:128, 0:128]]) T.writes([C[0:128, 0:128]]) T.evaluate( @@ -251,10 +261,12 @@ def tir_reordered_matmul(c: T.handle, a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i0, j0, k0 in T.grid(128, 128, 128): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] def test_arg_order(): diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index bc4bc4f56e19..ca3ab3aade98 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np + import tvm from tvm import te -import numpy as np +from tvm.driver.build_module import schedule_to_module def test_schedule0(): @@ -26,11 +28,8 @@ def test_schedule0(): A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1") s = te.create_schedule(A1.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A1]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule1(): @@ -42,12 +41,9 @@ def test_schedule1(): s = te.create_schedule(A1.op) xo, xi = s[A1].split(A1.op.axis[0], 8) s[A1].pragma(xo, "auto_unroll_max_step", 10) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A1]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule2(): @@ -60,11 +56,9 @@ def test_schedule2(): s = te.create_schedule(A2.op) xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + + mod = schedule_to_module(s, [A, A2]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule_scan(): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 1aae8cdd03e1..1a0dfd09a2df 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -25,17 +25,23 @@ def buffer_load_store_func(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128), "float32") C = T.alloc_buffer((128, 128), "float32") D = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128]) as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)]) as [i, j, k]: - with T.init(): + for ii, jj in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [ii, jj]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += ( + D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + ) @T.prim_func @@ -43,7 +49,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [16, 16], "float32") C = T.match_buffer(c, [16, 16], "float32") - with T.block([]): + with T.block(): T.reads([]) T.writes(B[0:16, 0:16]) A = T.allocate([256], "float32", "global") @@ -56,9 +62,8 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle")) for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] @@ -72,16 +77,20 @@ def lca_is_func_root(a: T.handle) -> None: def match_buffer_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([8, 8], "block") as [vi, vj]: - T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) - T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) - B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) - with T.block([16, 16], "AAA") as [i, j]: - AA = T.match_buffer(A[i, j], ()) - AA[()] = 1.0 - T.evaluate(B0.data) - T.evaluate(B1.data) + for i, j in T.grid(8, 8): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + for ii, jj in T.grid(16, 16): + with T.block("AAA"): + vii, vjj = T.axis.remap("SS", [ii, jj]) + AA = T.match_buffer(A[vii, vjj], ()) + AA[()] = 1.0 + T.evaluate(B0.data) + T.evaluate(B1.data) def test_buffer_load_store(): diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index e3a63c325434..4ea35c0a2d6c 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -27,57 +27,65 @@ def func() -> None: B = T.alloc_buffer((128, 128), "float32") C = T.alloc_buffer((128, 128), "float32") D = T.alloc_buffer((128, 128), "float32") - with T.block([]): + with T.block(): # Need add read/write region manually to avoid triggering block access region detector T.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]]) T.writes([A[0:12, 0:12]]) for i, j in T.grid(8, 8): A[i, j] = B[0, 0] + C[0, 0] - with T.block([2, 2]) as [vi, vj]: - T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) - T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) - for i, j in T.grid(4, 4): - A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12] + for i, j in T.grid(2, 2): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) + T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) + for i, j in T.grid(4, 4): + A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12] T.evaluate(D.data) @T.prim_func def match_buffer_func() -> None: - with T.block([], "root"): + with T.block("root"): A = T.alloc_buffer((128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") T.reads([]) T.writes([]) # Need add read/write region manually to avoid triggering block access region detector - with T.block([8, 8], "block") as [vi, vj]: - T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) - T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - AA = T.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) - B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) - B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) - with T.block([16, 16], "AAA") as [i, j]: - T.reads([]) - T.writes(AA[i, j]) - AAA = T.match_buffer(AA[i, j], ()) - AAA[()] = 1.0 - T.evaluate(B0.data) - T.evaluate(B1.data) + for i, j in T.grid(8, 8): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = T.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = T.match_buffer( + B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8) + ) + for ii, jj in T.grid(16, 16): + with T.block("AAA"): + vii, vjj = T.axis.remap("SS", [ii, jj]) + T.reads([]) + T.writes(AA[vii, vjj]) + AAA = T.match_buffer(AA[vii, vjj], ()) + AAA[()] = 1.0 + T.evaluate(B0.data) + T.evaluate(B1.data) @T.prim_func def opaque_block_func() -> None: - with T.block([], "root"): + with T.block("root"): A = T.alloc_buffer((16, 16), "float32") B = T.alloc_buffer((16, 16), "float32") T.reads([]) T.writes([]) # Need add read/write region manually to avoid triggering block access region detector for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes([B[i, 0:16]]) for j in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 @@ -88,8 +96,8 @@ def opaque_access_func() -> None: A = T.alloc_buffer([1024]) B = T.alloc_buffer([1024]) for i in T.serial(0, 8): - with T.block([8]) as [v]: - T.bind(v, i) + with T.block(): + v = T.axis.S(8, i) T.reads([A[v * 128 : v * 128 + 128]]) T.writes([B[v * 128 : v * 128 + 128]]) T.evaluate( diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 7129275aebcd..5ca9cf0da3c9 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -39,7 +39,7 @@ def buffer_load_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16)) C = T.match_buffer(c, (16, 16)) for i, j, k in T.grid(4, 16, 8): - with T.block([]): + with T.block(): T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) sub_A = T.match_buffer( @@ -55,7 +55,7 @@ def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16)) C = T.match_buffer(c, (16, 16)) for i, j, k in T.grid(4, 16, 8): - with T.block([]): + with T.block(): T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) for ii, kk in T.grid(4, 2): @@ -72,7 +72,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (32, 64, 128)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(2, 64, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) sub_A = T.match_buffer( @@ -93,7 +93,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: ) ) for i, j, k in T.grid(64, 2, 8): - with T.block([]): + with T.block(): Bs_0 = T.var("int32") Bs_1 = T.var("int32") T.reads([]) @@ -122,7 +122,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (32, 64, 128)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(2, 64, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) T.evaluate( @@ -137,7 +137,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: ) ) for i, j, k in T.grid(64, 2, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) T.evaluate( @@ -157,7 +157,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: def high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): As_0 = T.var("int32") As_1 = T.var("int32") T.reads([]) @@ -185,7 +185,7 @@ def high_dim_opaque_access(a: T.handle) -> None: def transformed_high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( @@ -205,7 +205,7 @@ def transformed_high_dim_opaque_access(a: T.handle) -> None: def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): As_0 = T.var("int32") As_1 = T.var("int32") T.reads([]) @@ -233,7 +233,7 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( @@ -254,7 +254,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (64, 64, 64)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(64, 4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -276,7 +276,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: offset_factor=1, ) for jj, kk in T.grid(4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -317,7 +317,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (64, 64, 64)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(64, 4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -326,7 +326,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: ] ) for jj, kk in T.grid(4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -362,7 +362,7 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None: A = T.match_buffer(a, (n * m, m)) B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) Bs_0 = T.var("int32") @@ -392,7 +392,7 @@ def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) A = T.match_buffer(a, (n * m, m)) B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) for ii, jj in T.grid(m, m): @@ -416,7 +416,7 @@ def rank0_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (8, 8)) B = T.match_buffer(b, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j], B[i, j]]) sub_A = T.match_buffer(A[i, j], (), offset_factor=1) @@ -440,7 +440,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (8, 8)) B = T.match_buffer(b, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j], B[i, j]]) A[i, j] = 1 @@ -461,7 +461,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: def fail_match_load(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads(A[i, j]) T.writes([]) sub_A = T.match_buffer(A[i, j], ()) @@ -472,7 +472,7 @@ def fail_match_load(a: T.handle) -> None: def fail_match_store(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j]) sub_A = T.match_buffer(A[i, j], ()) @@ -483,7 +483,7 @@ def fail_match_store(a: T.handle) -> None: def fail_buffer_bind(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): stride = T.var("int32") sub_A = T.match_buffer( A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 @@ -496,7 +496,7 @@ def fail_buffer_bind(a: T.handle) -> None: def fail_match_func_param(a: T.handle, m: T.handle, n: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): sub_A = T.match_buffer(A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1) for jj in range(0, 4): sub_A[i, j * 4 + jj] = 1 diff --git a/tests/python/unittest/test_tir_schedule_block_scope.py b/tests/python/unittest/test_tir_schedule_block_scope.py index 2182c7b9f449..ad789a010745 100644 --- a/tests/python/unittest/test_tir_schedule_block_scope.py +++ b/tests/python/unittest/test_tir_schedule_block_scope.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -58,9 +64,11 @@ def war_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index ff5b61a135eb..853f44affe5d 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -33,10 +33,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -45,20 +49,23 @@ def access_under_scope(b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([8, 8], "scope") as [i, j]: - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = A[vi, vj] + 1.0 - - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("B"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -68,76 +75,82 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: C = T.match_buffer(c, (128, 128), dtype="float16") D = T.match_buffer(d, (128, 128), dtype="float16") - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A[vi, vj]) - T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(D[vi, vj]) + D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = T.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) @T.prim_func @@ -147,15 +160,16 @@ def func_multi_consumer() -> None: C = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A[vi] @@ -163,12 +177,18 @@ def func_multi_consumer() -> None: def func_multi_producer() -> None: A = T.alloc_buffer((128)) B = T.alloc_buffer((128)) - with T.block([128], "A0") as [vi]: - A[vi] = 1.0 - with T.block([128], "A1") as [vi]: - A[vi] = 2.0 - with T.block([128], "B") as [vi]: - B[vi] = A[vi] + for i in range(128): + with T.block("A0"): + vi = T.axis.S(128, i) + A[vi] = 1.0 + for i in range(128): + with T.block("A1"): + vi = T.axis.S(128, i) + A[vi] = 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] ########## Expected function after cache_read ########## @@ -181,14 +201,22 @@ def cache_read_elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) A_global = T.alloc_buffer((128, 128)) B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A_global[vi, vj] * 2.0 - with T.block([128, 128], "B_local") as [vi, vj]: - B_local[vi, vj] = B[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_global[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B_local"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B_local[vi, vj] + 1.0 @T.prim_func @@ -198,27 +226,33 @@ def cache_read_under_scope(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) A_global = T.alloc_buffer((128, 128)) - with T.block([8, 8], "scope") as [i, j]: - A_local = T.alloc_buffer((128, 128), scope="local") - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "A_local") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_local[vi, vj] = A[vi, vj] - for x, y in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = A_local[vi, vj] + 1.0 - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A_global[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + A_local = T.alloc_buffer((128, 128), scope="local") + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("A_local"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_local[vi, vj] = A[vi, vj] + for x, y in T.grid(16, 16): + with T.block("B"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = A_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A_global[vi, vj] * 2.0 @T.prim_func @@ -229,78 +263,86 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) D = T.match_buffer(d, (128, 128), dtype="float16") A_global = T.alloc_buffer((128, 128), dtype="float16") - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A_global[vi, vj]) - T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A_global.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A_global.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi, vj]) + T.writes(D[vi, vj]) + D.data[vi * 128 + vj] = T.load("float16", A_global.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A_global.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A_global[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + C0 = T.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) @T.prim_func @@ -311,20 +353,21 @@ def cache_read_multi_consumer() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A_global[vi] = A[vi] for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A_global[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A_global[vi] @@ -335,14 +378,22 @@ def continuous_cache_read(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) B_shared = T.alloc_buffer((128, 128), scope="shared") B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B_shared") as [vi, vj]: - B_shared[vi, vj] = B[vi, vj] - with T.block([128, 128], "B_local") as [vi, vj]: - B_local[vi, vj] = B_shared[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + B_shared[vi, vj] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B_local"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = B_shared[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B_local[vi, vj] + 1.0 ########## Expected function after cache_write ########## @@ -355,14 +406,22 @@ def cache_write_elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) B_global = T.alloc_buffer((128, 128), scope="local") C_local = T.alloc_buffer((128, 128)) - with T.block([128, 128], "B_global") as [vi, vj]: - B_global[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "C_local") as [vi, vj]: - C_local[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = C_local[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B_global"): + vi, vj = T.axis.remap("SS", [i, j]) + B_global[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C_local"): + vi, vj = T.axis.remap("SS", [i, j]) + C_local[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = C_local[vi, vj] @T.prim_func @@ -372,33 +431,39 @@ def cache_write_under_scope(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) A_global = T.alloc_buffer((128, 128)) - with T.block([8, 8], "scope") as [i, j]: - A_local = T.alloc_buffer((128, 128), scope="local") - B_global = T.alloc_buffer((128, 128)) - for x, y in T.grid(16, 16): - with T.block([128, 128], "A_local") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_local[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_global[vi, vj] = A_local[vi, vj] - for x, y in T.grid(16, 16): - with T.block([128, 128], "B_global") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B_global[vi, vj] = A_global[vi, vj] + 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "B_global") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "A_global") as [vi, vj]: - A[vi, vj] = A_global[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + A_local = T.alloc_buffer((128, 128), scope="local") + B_global = T.alloc_buffer((128, 128)) + for x, y in T.grid(16, 16): + with T.block("A_local"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_local[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_global[vi, vj] = A_local[vi, vj] + for x, y in T.grid(16, 16): + with T.block("B_global"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B_global[vi, vj] = A_global[vi, vj] + 1.0 + for x, y in T.grid(16, 16): + with T.block("B_global"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -411,83 +476,95 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle B_global = T.alloc_buffer((128, 128), dtype="float16") C_global = T.alloc_buffer((128, 128), dtype="float16") - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A[vi, vj]) - T.writes(D_global[vi, vj]) - D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B_global.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(D_global[vi, vj]) + D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B_global.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C_global[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + C0 = T.match_buffer( + C_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = D_global[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = C_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = D_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = C_global[vi, vj] @T.prim_func @@ -498,20 +575,21 @@ def cache_write_multi_consumer() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A_global") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A_global"): + vi = T.axis.S(128, i * 16 + j) A_global[vi] = 1.0 for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = A_global[vi] for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A[vi] @@ -522,14 +600,22 @@ def continuous_cache_write(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) B_shared = T.alloc_buffer((128, 128), scope="shared") B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "B") as [vi, vj]: - B_local[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B") as [vi, vj]: - B_shared[vi, vj] = B_local[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_shared[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_shared[vi, vj] = B_local[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_shared[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 ########## Testcases for cache_read ########## diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 5235664595ad..6e956e1ee688 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -32,10 +32,15 @@ def two_elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -45,12 +50,13 @@ def two_elementwise_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for ax0, ax1 in T.grid(1, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i + ax0) - T.bind(vj, ax1) + with T.block("B"): + vi = T.axis.S(128, i + ax0) + vj = T.axis.S(128, ax1) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -59,22 +65,26 @@ def blockized_1(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128], "float32") B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([8, 8], "C_outer") as [vi_o, vj_o]: - T.reads([B[ - vi_o * 16 : vi_o * 16 + 16, - vj_o * 16 : vj_o * 16 + 16, - ]]) - T.writes([C[ - vi_o * 16 : vi_o * 16 + 16, - vj_o * 16 : vj_o * 16 + 16 - ]]) - for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "C_inner") as [vi, vj]: - T.bind(vi, vi_o * 16 + i_i) - T.bind(vj, vj_o * 16 + j_i) - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(8, 8): + with T.block("C_outer"): + vi_o, vj_o = T.axis.remap("SS", [i, j]) + T.reads([B[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16, + ]]) + T.writes([C[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16 + ]]) + for i_i, j_i in T.grid(16, 16): + with T.block("C_inner"): + vi = T.axis.S(128, vi_o * 16 + i_i) + vj = T.axis.S(128, vj_o * 16 + j_i) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -84,13 +94,12 @@ def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i0_0, i1_0 in T.grid(8, 8): for ax0, ax1 in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0_0 * 16 + ax0) - T.bind(vj, i1_0 * 16 + ax1) + with T.block("B"): + vi = T.axis.S(128, i0_0 * 16 + ax0) + vj = T.axis.S(128, i1_0 * 16 + ax1) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([8, 8], "C_outer") as [vi_o, vj_o]: - T.bind(vi_o, i0_0) - T.bind(vj_o, i1_0) + with T.block("C_outer"): + vi_o, vj_o = T.axis.remap("SS", [i0_0, i1_0]) T.reads([B[ vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16, @@ -100,9 +109,9 @@ def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: vj_o * 16 : vj_o * 16 + 16 ]]) for i0_1, i1_1 in T.grid(16, 16): - with T.block([128, 128], "C_inner") as [vi, vj]: - T.bind(vi, vi_o * 16 + i0_1) - T.bind(vj, vj_o * 16 + i1_1) + with T.block("C_inner"): + vi = T.axis.S(128, vi_o * 16 + i0_1) + vj = T.axis.S(128, vj_o * 16 + i1_1) C[vi, vj] = B[vi, vj] + 1.0 @@ -112,9 +121,8 @@ def blockized_2(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(8, 8): - with T.block([8, 8], "B_outer") as [vio, vjo]: - T.bind(vio, i_o) - T.bind(vjo, j_o) + with T.block("B_outer"): + vio, vjo = T.axis.remap("SS", [i_o, j_o]) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -124,14 +132,14 @@ def blockized_2(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B_inner") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B_inner"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for i_o, j_o, i_i, j_i in T.grid(4, 4, 32, 32): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 32 + i_i) - T.bind(vj, j_o * 32 + j_i) + with T.block("C"): + vi = T.axis.S(128, i_o * 32 + i_i) + vj = T.axis.S(128, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @@ -141,9 +149,8 @@ def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(8, 8): - with T.block([8, 8], "B_outer") as [vio, vjo]: - T.bind(vio, i_o) - T.bind(vjo, j_o) + with T.block("B_outer"): + vio, vjo = T.axis.remap("SS", [i_o, j_o]) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -153,14 +160,14 @@ def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B_inner") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B_inner"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for ax0, ax1 in T.grid(16, 16): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 16 + ax0) - T.bind(vj, j_o * 16 + ax1) + with T.block("C"): + vi = T.axis.S(128, i_o * 16 + ax0) + vj = T.axis.S(128, j_o * 16 + ax1) T.reads([B[vi, vj]]) T.writes([C[vi, vj]]) C[vi, vj] = B[vi, vj] + 1.0 @@ -173,9 +180,9 @@ def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(4, 4): for ax0, ax1 in T.grid(2, 2): - with T.block([8, 8], "blockized_B") as [vio, vjo]: - T.bind(vio, i_o * 2 + ax0) - T.bind(vjo, j_o * 2 + ax1) + with T.block("blockized_B"): + vio = T.axis.S(8, i_o * 2 + ax0) + vjo = T.axis.S(8, j_o * 2 + ax1) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -185,14 +192,14 @@ def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16, ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for i_i, j_i in T.grid(32, 32): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 32 + i_i) - T.bind(vj, j_o * 32 + j_i) + with T.block("C"): + vi = T.axis.S(128, i_o * 32 + i_i) + vj = T.axis.S(128, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -205,18 +212,28 @@ def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - with T.init(): - C_local[vi, vj] = 0.0 - C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C_local[vi, vj] = 0.0 + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -224,9 +241,9 @@ def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0_4, v1_4]: - T.bind(v0_4, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1_4, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0_4 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1_4 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0_4, v1_4] = C_local[v0_4, v1_4] @@ -240,14 +257,22 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -255,17 +280,17 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for i, j, k in T.grid(4, 4, 2048): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [vi, vj]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] @@ -279,14 +304,22 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -296,17 +329,17 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k_0 in T.serial(0, 256): for k_1 in T.unroll(0, 8): for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k_0 * 8 + k_1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k_0 * 8 + k_1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [vi, vj]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] @@ -320,12 +353,18 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -335,22 +374,22 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k_0 in T.serial(0, 256): for k_1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k_0 * 8 + k_1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k_0 * 8 + k_1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k_0 * 8 + k_1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k_0 * 8 + k_1) with T.init(): C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -364,10 +403,14 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -377,27 +420,27 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k0 in T.serial(0, 256): for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -411,8 +454,10 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -421,33 +466,33 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for k0 in T.serial(0, 256): for i, j in T.grid(8, 64): - with T.block([2048, 2048], "A_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, by * 64 + j) + with T.block("A_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, by * 64 + j) A_shared[v0, v1] = A[v0, v1] for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -469,38 +514,38 @@ def cuda_matmul_5(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for k0 in T.serial(0, 256): for i, j in T.grid(8, 64): - with T.block([2048, 2048], "A_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, by * 64 + j) + with T.block("A_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, by * 64 + j) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(8, 64): - with T.block([2048, 2048], "B_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, bx * 64 + j) + with T.block("B_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, bx * 64 + j) B_shared[v0, v1] = B[v0, v1] for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -510,12 +555,14 @@ def tiled(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("B"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -525,14 +572,14 @@ def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i_0, j_0, i_1 in T.grid(8, 8, 16): for j_1 in T.serial(0, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("B"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 for j_1 in T.serial(0, 16): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("C"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) C[vi, vj] = B[vi, vj] + 1.0 @@ -544,17 +591,15 @@ def factorized(a: T.handle, b: T.handle) -> None: for j in T.thread_binding(0, 16, thread = "blockIdx.x"): for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): for i_i, k in T.grid(4, 16): - with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: - T.bind(vi, i_o * 4 + i_i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B_rf"): + vi = T.axis.S(16, i_o * 4 + i_i) + vj, vk = T.axis.remap("SR", [j, k]) with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] for i, k in T.grid(16, 16): - with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, k) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] @@ -568,17 +613,17 @@ def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None: for j in T.thread_binding(0, 16, thread = "blockIdx.x"): for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): for i_i, k in T.grid(4, 16): - with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: - T.bind(vi, i_o * 4 + i_i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B_rf"): + vi = T.axis.S(16, i_o * 4 + i_i) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k) with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] for k in T.serial(0, 4): - with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]: - T.bind(vi, j) - T.bind(vk, i_o * 4 + k) + with T.block("B"): + vi = T.axis.S(16, j) + vk = T.axis.R(16, i_o * 4 + k) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] @@ -591,17 +636,19 @@ def fail_subtree_compact_dataflow(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for j in range(0, 64): - with T.block([128, 128], "B_0") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B_0"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 64): - with T.block([128, 128], "B_1") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j + 64) + with T.block("B_1"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j + 64) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -611,13 +658,16 @@ def fail_all_consumers_under_loop(a: T.handle, c: T.handle, d: T.handle) -> None C = T.match_buffer(c, (128, 128), "float32") D = T.match_buffer(d, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "D") as [vi, vj]: + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + 1.0 @@ -628,13 +678,16 @@ def fail_all_producers_under_loop(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer((128, 128), "float32") D = T.match_buffer(d, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "D") as [vi, vj]: + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + C[vi, vj] @@ -644,10 +697,12 @@ def read_out_of_bound(a: T.handle, c:T.handle) -> None: B = T.alloc_buffer([16], "float32") C = T.match_buffer(c, [16], "float32") for i in T.serial(0, 16): - with T.block([16], "B") as [v]: + with T.block("B"): + v = T.axis.S(16, i) B[v] = A[v] for j in T.serial(0, 16): - with T.block([16], "C") as [v]: + with T.block("C"): + v = T.axis.S(16, j) T.reads(B[v : v + 2]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") @@ -659,11 +714,11 @@ def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16], "float32") for j in T.serial(0, 16): for i in T.serial(0, T.min(1, 15 - j) + 1): - with T.block([16], "B") as [v]: - T.bind(v, j + i) + with T.block("B"): + v = T.axis.S(16, j + i) B[v] = A[v] - with T.block([16], "C") as [v]: - T.bind(v, j) + with T.block("C"): + v = T.axis.S(16, j) T.reads([B[v : v + 2]]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index f9049f6da732..617c75b75cd9 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -31,10 +31,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -43,12 +47,18 @@ def elementwise_multi_producer_consumer(a: T.handle, c: T.handle, d: T.handle) - B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers @T.prim_func @@ -56,10 +66,14 @@ def elementwise_multi_consumer_inlined(a: T.handle, c: T.handle, d: T.handle) -> A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] @T.prim_func @@ -67,18 +81,24 @@ def elementwise_standalone(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] + 1.0 @T.prim_func def elementwise_standalone_dce(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] + 1.0 @T.prim_func @@ -88,14 +108,12 @@ def elementwise_under_loop(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -103,8 +121,10 @@ def elementwise_under_loop(a: T.handle, c: T.handle) -> None: def elementwise_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @T.prim_func @@ -113,11 +133,15 @@ def fail_multi_reader_writer(a: T.handle, d: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.alloc_buffer((128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - C[vi, vj] = A[vi, vj] + 2.0 - with T.block([128, 128], "C") as [vi, vj]: - D[vi, vj] = B[vi, vj] + C[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + C[vi, vj] = A[vi, vj] + 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = B[vi, vj] + C[vi, vj] @T.prim_func @@ -125,18 +149,24 @@ def elementwise_multi_reverse_loads(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 @T.prim_func def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 @T.prim_func @@ -144,12 +174,16 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - T.reads(B[0:128, 0:128]) - T.writes(C[0:128, 0:128]) - C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[0:128, 0:128]) + T.writes(C[0:128, 0:128]) + C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 @T.prim_func @@ -157,13 +191,17 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - T.reads(B[0:128, 0:128]) - T.writes(C[0:128, 0:128]) - T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) - C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[0:128, 0:128]) + T.writes(C[0:128, 0:128]) + T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) + C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 @T.prim_func @@ -171,11 +209,15 @@ def buffer_matched(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) - C[vi, vj] = Bb[0, 0] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) + C[vi, vj] = Bb[0, 0] + 1.0 @T.prim_func @@ -183,10 +225,13 @@ def elementwise_predicate(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) T.where(B[i, j] < 10.0) C[vi, vj] = B[vi, vj] + 1.0 @@ -196,7 +241,8 @@ def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) T.where(A[i, j] * 2.0 < 10.0) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -206,18 +252,24 @@ def elementwise_multi_loads(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 126], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] @T.prim_func def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 126], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py index 7a9c8e01d355..ad6a1931bb0b 100644 --- a/tests/python/unittest/test_tir_schedule_error.py +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -31,10 +31,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) - for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + for k in range(128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 60269ac01c14..93876c668913 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -31,9 +31,10 @@ def element_wise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -42,9 +43,8 @@ def element_wise_parallelized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i0 in T.parallel(0, 128): for i1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 @@ -54,9 +54,8 @@ def element_wise_i_bound(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i0 in T.thread_binding(0, 128, thread="threadIdx.x"): for i1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 @@ -67,14 +66,13 @@ def element_wise_compute_at_split(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o, j1i in T.grid(32, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -85,15 +83,14 @@ def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o in T.serial(0, 32): for j1i in T.vectorized(0, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -102,10 +99,10 @@ def element_wise_split_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) for i, j_0, j_1 in T.grid(128, 13, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -116,10 +113,10 @@ def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for j_0 in T.parallel(0, 13): for j_1 in T.serial(0, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -129,10 +126,10 @@ def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128]) for i in T.vectorized(0, 128): for j_0, j_1 in T.grid(13, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -143,15 +140,14 @@ def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o in T.thread_binding(0, 32, thread="threadIdx.x"): for j1i in T.serial(0, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -161,10 +157,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -172,10 +170,12 @@ def rowsum(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -184,9 +184,8 @@ def rowsum_unrolled(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i0 in T.unroll(0, 128): for i1 in T.serial(0, 128): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i0) - T.bind(vk, i1) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -198,9 +197,9 @@ def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, T.floordiv(k * k, 2)) + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, T.floordiv(k * k, 2)) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -211,10 +210,12 @@ def rowsum_not_compact_data_flow(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vk] = 0.0 - B[vk] = B[vk] + A[vi, vk] + for i, k in T.grid(128, 16): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vk] = 0.0 + B[vk] = B[vk] + A[vi, vk] @T.prim_func @@ -223,9 +224,8 @@ def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i0 in T.serial(0, 128): for i1 in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i0) - T.bind(vk, i1) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -235,10 +235,48 @@ def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: def opaque_block(a: T.handle) -> None: A = T.match_buffer(a, (16,)) for i in T.serial(0, 15): - with T.block([], "opaque"): + with T.block("opaque"): A[i + 1] = A[i + 1] + A[i] +@T.prim_func +def block_inside_init(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i in T.serial(0, 128): + with T.block("outer"): + vi = T.axis.S(128, i) + with T.init(): + for j in T.serial(0, 128): + with T.block("init"): + vj = T.axis.S(128, j) + B[vi, vj] = 0.0 + for k in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block("inner"): + vj, vk = T.axis.remap("SR", [j, k]) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + +@T.prim_func +def thread_bound_block_inside_init(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("outer"): + vi = T.axis.S(128, i) + with T.init(): + for j in T.serial(0, 128): + with T.block("init"): + vj = T.axis.S(128, j) + B[vi, vj] = 0.0 + for k in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block("inner"): + vj, vk = T.axis.remap("SR", [j, k]) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + # pylint: enable=no-member,invalid-name,unused-variable @@ -361,5 +399,13 @@ def test_bind_after_bind(): verify_trace_roundtrip(s, mod=element_wise) +def test_block_inside_init(): + s = tir.Schedule(block_inside_init, debug_mask="all") + (i,) = s.get_loops(s.get_block("outer")) + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], thread_bound_block_inside_init) + verify_trace_roundtrip(s, mod=block_inside_init) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 8460b5cf3e66..e158f6a026e1 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -32,18 +32,17 @@ def rowsum_blockized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32, 4]) A = T.match_buffer(a, [32, 4, 128]) for i0, i2_0 in T.grid(32, 16): - with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]: - T.bind(io, i0) - T.bind(ko, i2_0) + with T.block("blockized_B"): + io, ko = T.axis.remap("SR", [i0, i2_0]) with T.init(): for i1 in T.serial(0, 4): - with T.block([4], "B_init") as [ii_init]: - T.bind(ii_init, i1) + with T.block("B_init"): + ii_init = T.axis.S(4, i1) B[io, ii_init] = 0.0 for i1_1, i2_1 in T.grid(4, 8): - with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]: - T.bind(ii, i1_1) - T.bind(k, ko * 8 + i2_1) + with T.block("B"): + ii = T.axis.S(4, i1_1) + k = T.axis.R(128, ko * 8 + i2_1) B[io, ii] = B[io, ii] + A[io, ii, k] @@ -52,11 +51,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -65,11 +65,15 @@ def matmul_decompose0(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = 0.0 + for i, j in T.grid(128, 128): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = 0.0 - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -78,16 +82,19 @@ def matmul_decompose1(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32, 4], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 32): - with T.block([32], "blockized_B_init") as [io]: + with T.block("blockized_B_init"): + io = T.axis.S(32, i0) for i1 in T.serial(0, 4): - with T.block([4], "B_init") as [ii]: + with T.block("B_init"): + ii = T.axis.S(4, i1) B[io, ii] = T.float32(0) for i0, i2_o in T.grid(32, 16): - with T.block([32, T.reduce_axis(0, 16)], "blockized_B_update") as [io, ko]: + with T.block("blockized_B_update"): + io, ko = T.axis.remap("SR", [i0, i2_o]) for i1, i2_i in T.grid(4, 8): - with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]: - T.bind(ii, i1) - T.bind(k, ((ko * 8) + i2_i)) + with T.block("B"): + ii = T.axis.S(4, i1) + k = T.axis.R(128, ko * 8 + i2_i) B[io, ii] = B[io, ii] + A[io, ii, k] @@ -98,10 +105,12 @@ def matmul_decompose2(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(128, 128): - with T.block([128, 128], "update_init") as [vi_init, vj_init]: + with T.block("update_init"): + vi_init, vj_init = T.axis.remap("SS", [i0, i1]) C[vi_init, vj_init] = T.float32(0) for i2 in T.serial(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [vi, vj, vk]: + with T.block("update_update"): + vi, vj, vk = T.axis.remap("SSR", [i0, i1, i2]) C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) @@ -112,12 +121,10 @@ def matmul_decompose_fail3(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, k, j in T.grid(128, 128, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -127,25 +134,21 @@ def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) for i0_0 in T.serial(0, 16): for i0_1_init, i1_init in T.grid(8, 128): - with T.block([128, 128], "update_init") as [vi_init, vj_init]: - T.bind(vi_init, ((i0_0 * 8) + i0_1_init)) - T.bind(vj_init, i1_init) + with T.block("update_init"): + vi_init = T.axis.S(128, i0_0 * 8 + i0_1_init) + vj_init = T.axis.S(128, i1_init) C[vi_init, vj_init] = T.float32(0) for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7): - with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [ - vi, - vj, - vk, - ]: + with T.block("update_update"): T.where((((i2_0 * 7) + i2_1) < 128)) - T.bind(vi, ((i0_0 * 8) + i0_1)) - T.bind(vj, i1) - T.bind(vk, ((i2_0 * 7) + i2_1)) + vi = T.axis.S(128, i0_0 * 8 + i0_1) + vj = T.axis.S(128, i1) + vk = T.axis.R(128, i2_0 * 7 + i2_1) C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index a60ab8dca972..8267a369cf5d 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -30,8 +30,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + for i, j, k, l in T.grid(128, 128, 128, 128): + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @T.prim_func @@ -39,11 +41,9 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 8): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l * 16) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -53,7 +53,8 @@ def elementwise_dependent_loop(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128, 128)) for i in T.serial(0, 128): for j, k, l in T.grid(128, i, 128): - with T.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -62,8 +63,9 @@ def elementwise_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -74,16 +76,12 @@ def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block([128, 128, 128], "C") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 @@ -92,12 +90,11 @@ def elementwise_with_loops_not_same_scope(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) for k in T.serial(0, 128): - with T.block([128], "B") as [vk]: - T.bind(vk, k) + with T.block("B"): + vk = T.axis.S(128, k) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -108,10 +105,9 @@ def elementwise_with_wrong_block_var_type(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block([128, 128, T.scan_axis(0, 128)], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + vk = T.axis.scan(128, k) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -122,11 +118,8 @@ def elementwise_reordered(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for l, j, k, i in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -135,11 +128,8 @@ def elementwise_reordered2(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for k, j, i, l in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -148,12 +138,9 @@ def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for l, j, k, i in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -161,14 +148,18 @@ def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None: def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") - with T.block([16, 16], "A") as [vi, vj]: - T.reads([]) - T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) - with T.block([16, 16], "B") as [vi, vj]: - T.reads([]) - T.writes([B[0:16, 0:16]]) - T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + for i, j in T.grid(16, 16): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + vj, 1) + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) @T.prim_func @@ -176,16 +167,14 @@ def opaque_access_reorder(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") for j, i in T.grid(16, 16): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, vi * 16 + vj, 1) for j, i in T.grid(16, 16): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 78b6a4696baa..bd474ed34295 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -34,10 +34,9 @@ def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - T.bind(vi, i0) - T.bind(vj, i1) - T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner)) + with T.block("update"): + vi, vj = T.axis.remap("SS", [i0, i1]) + vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) T.writes([C[vi, vj]]) with T.init(): @@ -53,18 +52,12 @@ def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([4, 128, 128]) for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block([4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf") as [ - vi2_inner_inner, - vi, - vj, - vi2_outer, - vi2_inner_outer, - ]: - T.bind(vi2_inner_inner, i2_inner_inner) - T.bind(vi, i0) - T.bind(vj, i1) - T.bind(vi2_outer, i2_outer) - T.bind(vi2_inner_outer, i2_inner_outer) + with T.block("update_rf"): + vi2_inner_inner = T.axis.S(4, i2_inner_inner) + vi = T.axis.S(128, i0) + vj = T.axis.S(128, i1) + vi2_outer = T.axis.R(4, i2_outer) + vi2_inner_outer = T.axis.R(8, i2_inner_outer) with T.init(): C_rf[vi2_inner_inner, vi, vj] = 0.0 C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( @@ -73,14 +66,8 @@ def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: ) for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4): - with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [ - vi2_inner_inner_1, - vi_1, - vj_1, - ]: - T.bind(vi2_inner_inner_1, i2_inner_inner_1) - T.bind(vi_1, i0_1) - T.bind(vj_1, i1_1) + with T.block("update"): + vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1]) with T.init(): C[vi_1, vj_1] = 0.0 C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] @@ -93,13 +80,17 @@ def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None: D = T.match_buffer(d, [256, 256]) C = T.alloc_buffer([256, 256]) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([256, 256], "D") as [vi, vj]: - D[vi, vj] = C[vi, vj] + for i, j in T.grid(256, 256): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = C[vi, vj] @T.prim_func @@ -108,10 +99,12 @@ def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] @T.prim_func @@ -122,17 +115,13 @@ def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.ha D = T.match_buffer(d, [128, 128]) for k, i, j in T.grid(128, 128, 128): - with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: - T.bind(ck, k) - T.bind(ci, i) - T.bind(cj, j) + with T.block("C"): + ck, ci, cj = T.axis.remap("RSS", [k, i, j]) with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj] - with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]: - T.bind(dk, k) - T.bind(di, i) - T.bind(dj, j) + with T.block("D"): + dk, di, dj = T.axis.remap("RSS", [k, i, j]) with T.init(): D[di, dj] = 0.0 D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj] @@ -143,10 +132,12 @@ def square_sum(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [16, 256, 256]) C = T.match_buffer(c, [16]) - with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]: - with T.init(): - C[b] = 0.0 - C[b] = C[b] + A[b, i, j] * A[b, i, j] + for b0, i0, j0 in T.grid(16, 256, 256): + with T.block("C"): + b, i, j = T.axis.remap("SRR", [b0, i0, j0]) + with T.init(): + C[b] = 0.0 + C[b] = C[b] + A[b, i, j] * A[b, i, j] @T.prim_func @@ -156,18 +147,15 @@ def square_sum_rfactor(a: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([16, 256]) for i0, i1, i2 in T.grid(16, 256, 256): - with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]: - T.bind(vi2, i2) - T.bind(b, i0) - T.bind(i, i1) + with T.block("C_rf"): + vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) with T.init(): C_rf[b, vi2] = 0.0 C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) for i0_1, i2_1 in T.grid(16, 256): - with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]: - T.bind(vi2_1, i2_1) - T.bind(b_1, i0_1) + with T.block("C"): + vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[b_1, vi2_1] @@ -180,18 +168,18 @@ def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer([16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]: - T.bind(b, i0) - T.bind(i, T.floordiv(i1_i2_fused_outer, 256)) - T.bind(j, T.floormod(i1_i2_fused_outer, 256)) + with T.block("C"): + b = T.axis.S(16, i0) + i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) T.reads([C[b], A[b, i, j]]) T.writes([C[b]]) with T.init(): C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) for i0_1 in T.serial(0, 16): - with T.block([16], "D") as [b_1]: - T.bind(b_1, i0_1) + with T.block("D"): + b_1 = T.axis.S(16, i0_1) T.reads([C[b_1]]) T.writes([D[b_1]]) D[b_1] = T.sqrt(C[b_1], dtype="float32") @@ -205,31 +193,24 @@ def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None: C_rf = T.alloc_buffer([1, 16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block([1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf") as [ - vi1_i2_fused_inner, - b, - i, - j, - ]: - T.bind(vi1_i2_fused_inner, i1_i2_fused_inner) - T.bind(b, i0) - T.bind(i, T.floordiv(i1_i2_fused_outer, 256)) - T.bind(j, T.floormod(i1_i2_fused_outer, 256)) + with T.block("C_rf"): + vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0]) + i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) with T.init(): C_rf[vi1_i2_fused_inner, b] = 0.0 C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1): - with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]: - T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1) - T.bind(b_1, i0_1) + with T.block("C"): + vi1_i2_fused_inner_1, b_1 = T.axis.remap("RS", [i1_i2_fused_inner_1, i0_1]) with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] for i0_2 in T.serial(0, 16): - with T.block([16], "D") as [b_2]: - T.bind(b_2, i0_2) + with T.block("D"): + b_2 = T.axis.S(16, i0_2) D[b_2] = T.sqrt(C[b_2], dtype="float32") @@ -238,8 +219,10 @@ def element_wise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -247,10 +230,12 @@ def rowsum(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -259,9 +244,9 @@ def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, T.floordiv(k * k, 2)) + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, T.floordiv(k * k, 2)) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -272,10 +257,12 @@ def rowsum_not_dominant(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi, vk] = 0.0 - B[vi, vk] = B[vi, vk] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi, vk] = 0.0 + B[vi, vk] = B[vi, vk] + A[vi, vk] @T.prim_func @@ -285,9 +272,8 @@ def rowsum_not_serial(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for k in T.parallel(0, 128): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, k) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -298,10 +284,12 @@ def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 1.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 1.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -309,10 +297,12 @@ def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] - A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] - A[vi, vk] @T.prim_func @@ -321,9 +311,9 @@ def rowsum_transformed(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for io, ii_ko_fused, ki in T.grid(32, 128, 4): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32)) - T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki) + with T.block("B"): + vi = T.axis.S(128, io * 4 + T.floordiv(ii_ko_fused, 32)) + vk = T.axis.R(128, T.floormod(ii_ko_fused, 32) * 4 + ki) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -334,10 +324,12 @@ def rowsum_zero_dim(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128]) B = T.match_buffer(b, []) - with T.block([T.reduce_axis(0, 128)], "B") as [k]: - with T.init(): - B[()] = 0.0 - B[()] = B[()] + A[k] + for k0 in range(128): + with T.block("B"): + k = T.axis.R(128, k0) + with T.init(): + B[()] = 0.0 + B[()] = B[()] + A[k] @T.prim_func @@ -346,15 +338,19 @@ def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, []) B_rf = T.alloc_buffer([128]) - with T.block([128], "B_rf") as [vi0]: - with T.init(): - B_rf[vi0] = 0.0 - B_rf[vi0] = B_rf[vi0] + A[vi0] + for i in range(128): + with T.block("B_rf"): + vi0 = T.axis.S(128, i) + with T.init(): + B_rf[vi0] = 0.0 + B_rf[vi0] = B_rf[vi0] + A[vi0] - with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]: - with T.init(): - B[()] = 0.0 - B[()] = B[()] + B_rf[vi0_1] + for i in range(128): + with T.block("B"): + vi0_1 = T.axis.R(128, i) + with T.init(): + B[()] = 0.0 + B[()] = B[()] + B_rf[vi0_1] @T.prim_func @@ -362,10 +358,10 @@ def rowsum_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") B = T.match_buffer(b, [128], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.block("B"): T.where(k_0 * 10 + k_1 < 128) - T.bind(vi, i) - T.bind(vk, k_0 * 10 + k_1) + vi = T.axis.S(128, i) + vk = T.axis.R(128, k_0 * 10 + k_1) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -377,18 +373,15 @@ def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") B_rf = T.alloc_buffer([128, 13], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): - with T.block([13, 128, T.reduce_axis(0, 10)], "B_rf") as [vk_0, vi, vk_1]: + with T.block("B_rf"): + vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1]) T.where(k_0 * 10 + k_1 < 128) - T.bind(vk_0, k_0) - T.bind(vi, i) - T.bind(vk_1, k_1) with T.init(): B_rf[vi, vk_0] = T.float32(0) B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1] for i, k_0 in T.grid(128, 13): - with T.block([T.reduce_axis(0, 13), 128], "B") as [vk_0, vi]: - T.bind(vk_0, k_0) - T.bind(vi, i) + with T.block("B"): + vk_0, vi = T.axis.remap("RS", [k_0, i]) with T.init(): B[vi] = T.float32(0) B[vi] = B[vi] + B_rf[vi, vk_0] @@ -405,35 +398,31 @@ def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None: for i in T.serial(0, 16): for j1 in T.serial(0, 16): for k1o, k1i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]: - T.bind(ci, i) - T.bind(cj, j1) - T.bind(ck, k1o * 4 + k1i) + with T.block("C"): + ci, cj = T.axis.remap("SS", [i, j1]) + ck = T.axis.R(16, k1o * 4 + k1i) with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, cj, ck] for k2o, k2i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]: - T.bind(di, i) - T.bind(dj, j1) - T.bind(dk, k2o * 4 + k2i) + with T.block("D"): + di, dj = T.axis.remap("SS", [i, j1]) + dk = T.axis.R(16, k2o * 4 + k2i) with T.init(): D[di, dj] = 0.0 D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] for j2 in T.serial(0, 16): for k3o, k3i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]: - T.bind(ei, i) - T.bind(ej, j2) - T.bind(ek, k3o * 4 + k3i) + with T.block("E"): + ei, ej = T.axis.remap("SS", [i, j2]) + ek = T.axis.R(16, k3o * 4 + k3i) with T.init(): E[ei, ej] = 0.0 E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] for k4o, k4i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]: - T.bind(fi, i) - T.bind(fj, j2) - T.bind(fk, k4o * 4 + k4i) + with T.block("F"): + fi, fj = T.axis.remap("SS", [i, j2]) + fk = T.axis.R(16, k4o * 4 + k4i) with T.init(): F[fi, fj] = 0.0 F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] @@ -449,46 +438,38 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: C_rf = T.alloc_buffer([16, 16, 4]) for i, j1, k1o, k1i in T.grid(16, 16, 4, 4): - with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]: - T.bind(vk1o, k1o) - T.bind(ci, i) - T.bind(cj, j1) - T.bind(vk1i, k1i) + with T.block("C_rf"): + vk1o, ci, cj, vk1i = T.axis.remap("SSSR", [k1o, i, j1, k1i]) with T.init(): C_rf[ci, cj, vk1o] = 0.0 C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)] for i_1 in T.serial(0, 16): for j1_1 in T.serial(0, 16): for k1o_1 in T.serial(0, 4): - with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]: - T.bind(vk1o_1, k1o_1) - T.bind(ci_1, i_1) - T.bind(cj_1, j1_1) + with T.block("C"): + vk1o_1, ci_1, cj_1 = T.axis.remap("RSS", [k1o_1, i_1, j1_1]) with T.init(): C[ci_1, cj_1] = 0.0 C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] for k2o, k2i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]: - T.bind(di, i_1) - T.bind(dj, j1_1) - T.bind(dk, (k2o * 4) + k2i) + with T.block("D"): + di, dj = T.axis.remap("SS", [i_1, j1_1]) + dk = T.axis.R(16, k2o * 4 + k2i) with T.init(): D[di, dj] = 0.0 D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] for j2 in T.serial(0, 16): for k3o, k3i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]: - T.bind(ei, i_1) - T.bind(ej, j2) - T.bind(ek, (k3o * 4) + k3i) + with T.block("E"): + ei, ej = T.axis.remap("SS", [i_1, j2]) + ek = T.axis.R(16, k3o * 4 + k3i) with T.init(): E[ei, ej] = 0.0 E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] for k4o, k4i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]: - T.bind(fi, i_1) - T.bind(fj, j2) - T.bind(fk, (k4o * 4) + k4i) + with T.block("F"): + fi, fj = T.axis.remap("SS", [i_1, j2]) + fk = T.axis.R(16, k4o * 4 + k4i) with T.init(): F[fi, fj] = 0.0 F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index c93c7ca63aa8..fbf0a6a5bd78 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -32,8 +32,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for i, j, k in T.grid(128, 128, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 29cfe8cadfb3..d2365c39c9cb 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -30,8 +30,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for i, j, k in T.grid(128, 128, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @T.prim_func @@ -40,7 +42,10 @@ def elementwise_dependent_loops(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i in T.serial(0, 128): for j, k in T.grid(i, 128): - with T.block([128, i, 128], "B") as [vi, vj, vk]: + with T.block("B"): + vi = T.axis.S(128, i) + vj = T.axis.S(i, j) + vk = T.axis.S(128, k) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -49,7 +54,8 @@ def elementwise_symbolic(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k in T.grid(128, 128, n): - with T.block([128, 128, n], "B") as [vi, vj, vk]: + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -58,10 +64,10 @@ def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i_j_k_fused in T.serial(0, (n * 16384)): - with T.block([128, 128, n], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(i_j_k_fused, (n * 128))) - T.bind(vj, T.floormod(T.floordiv(i_j_k_fused, n), 128)) - T.bind(vk, T.floormod(i_j_k_fused, n)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128)) + vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, n), 128)) + vk = T.axis.S(n, T.floormod(i_j_k_fused, n)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -72,11 +78,10 @@ def elementwise_symbolic_split(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k0, k1 in T.grid(128, 128, 10, T.floordiv((n + 9), 10)): - with T.block([128, 128, n], "B") as [vi, vj, vk]: + with T.block("B"): T.where((((k0 * T.floordiv((n + 9), 10)) + k1) < n)) - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, ((k0 * T.floordiv((n + 9), 10)) + k1)) + vi, vj = T.axis.remap("SS", [i, j]) + vk = T.axis.S(n, k0 * T.floordiv(n + 9, 10) + k1) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -89,10 +94,12 @@ def elementwise_with_seq(a: T.handle, b: T.handle) -> None: C = T.alloc_buffer((128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block([128, 128, 128], "C") as [vi, vj, vk]: + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 @@ -102,10 +109,8 @@ def elementwise_with_anno(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128, annotations={"useless_annotation": True}): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -117,10 +122,8 @@ def elementwise_with_thread_binding(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -132,10 +135,8 @@ def elementwise_with_starting_point(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(10, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -146,13 +147,11 @@ def elementwise_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block([], "opaque"): + with T.block("opaque"): T.reads([A[i, j, k]]) T.writes([B[i, j, k]]) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -163,10 +162,10 @@ def elementwise_fused(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for fused in T.serial(0, 2097152): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(fused, 16384)) - T.bind(vj, T.floormod(T.floordiv(fused, 128), 128)) - T.bind(vk, T.floormod(fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(fused, 16384)) + vj = T.axis.S(128, T.floormod(T.floordiv(fused, 128), 128)) + vk = T.axis.S(128, T.floormod(fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -177,10 +176,10 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, ((i1 * 64) + i3)) - T.bind(vj, ((j1 * 32) + j2)) - T.bind(vk, ((k1 * 8) + k2)) + with T.block("B"): + vi = T.axis.S(128, i1 * 64 + i3) + vj = T.axis.S(128, j1 * 32 + j2) + vk = T.axis.S(128, k1 * 8 + k2) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -191,10 +190,10 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i1 * 64 + i3) - T.bind(vj, j1 * 64 + j3) - T.bind(vk, k1 * 64 + k3) + with T.block("B"): + vi = T.axis.S(128, i1 * 64 + i3) + vj = T.axis.S(128, j1 * 64 + j3) + vk = T.axis.S(128, k1 * 64 + k3) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -205,16 +204,11 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.where( - ( - ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) - and (((k0 * 43) + k1) < 128) - ) - ) - T.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) - T.bind(vj, j1) - T.bind(vk, ((k0 * 43) + k1)) + with T.block("B"): + T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128) + vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2) + vj = T.axis.S(128, j1) + vk = T.axis.S(128, k0 * 43 + k1) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -225,7 +219,7 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) for i_j_k_fused in T.serial(0, 2097152): - with T.block([], "opaque"): + with T.block("opaque"): T.reads( [ A[ @@ -244,10 +238,10 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: ] ] ) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(i_j_k_fused, 16384)) - T.bind(vj, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) - T.bind(vk, T.floormod(i_j_k_fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384)) + vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) + vk = T.axis.S(128, T.floormod(i_j_k_fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -259,13 +253,12 @@ def elementwise_split_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) for i0, i1, j, k in T.grid(8, 16, 128, 128): - with T.block([], "opaque"): + with T.block("opaque"): T.reads([A[i0 * 16 + i1, j, k]]) T.writes([B[i0 * 16 + i1, j, k]]) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i0 * 16 + i1) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi = T.axis.S(128, i0 * 16 + i1) + vj, vk = T.axis.remap("SS", [j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -275,14 +268,18 @@ def elementwise_split_with_opaque_block(a: T.handle, b: T.handle) -> None: def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") - with T.block([16, 16], "A") as [vi, vj]: - T.reads([]) - T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) - with T.block([16, 16], "B") as [vi, vj]: - T.reads([]) - T.writes([B[0:16, 0:16]]) - T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + for i, j in T.grid(16, 16): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + vj, 1) + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) @T.prim_func @@ -290,16 +287,16 @@ def opaque_access_fused(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16]) B = T.match_buffer(b, [16, 16]) for i_j_fused in T.serial(0, 256): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, T.floordiv(i_j_fused, 16)) - T.bind(vj, T.floormod(i_j_fused, 16)) + with T.block("A"): + vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) + vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, ((vi * 16) + vj), 1, 1) for i_j_fused in T.serial(0, 256): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, T.floordiv(i_j_fused, 16)) - T.bind(vj, T.floormod(i_j_fused, 16)) + with T.block("B"): + vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) + vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")) @@ -310,16 +307,16 @@ def opaque_access_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16)) B = T.match_buffer(b, (16, 16)) for i, j0, j1 in T.grid(16, 4, 4): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, ((j0 * 4) + j1)) + with T.block("A"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, ((vi * 16) + vj), 1, 1) for i, j0, j1 in T.grid(16, 4, 4): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, ((j0 * 4) + j1)) + with T.block("B"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")) @@ -331,9 +328,9 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (127, 128)) for i in T.serial(0, 4): for j, k in T.grid(T.min(31, 126 - i * 32) + 1, 128): - with T.block([127, 128], "B") as [vi, vj]: - T.bind(vi, i * 32 + j) - T.bind(vj, k) + with T.block("B"): + vi = T.axis.S(127, i * 32 + j) + vj = T.axis.S(128, k) B[vi, vj] = A[vi, vj] @@ -343,12 +340,12 @@ def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [127, 128]) for i in T.grid(4): for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128): - with T.block([127, 128], "B") as [vi, vj]: - T.bind( - vi, + with T.block("B"): + vi = T.axis.S( + 127, i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1), ) - T.bind(vj, T.floormod(j_k_fused, 128)) + vj = T.axis.S(128, T.floormod(j_k_fused, 128)) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = A[vi, vj] diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py index 94e1b4a6b395..bc62fa1ba950 100644 --- a/tests/python/unittest/test_tir_schedule_state.py +++ b/tests/python/unittest/test_tir_schedule_state.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -55,22 +61,28 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: def block_in_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([128], "B") as vi: - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - B[vi, 0] = A[vi, 0] - if A[vi, 0] == 0.0: - with T.block([], "C"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "D") as vj: - B[vi, vj] = A[vi, vj] * 3.0 - else: - with T.block([], "E"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "F") as vj: - B[vi, vj] = A[vi, vj] * 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + B[vi, 0] = A[vi, 0] + if A[vi, 0] == 0.0: + with T.block("C"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("D"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 3.0 + else: + with T.block("E"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("F"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index e2b39ce7c289..e3bd000c2e70 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = 0.0 for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -55,22 +61,28 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: def block_in_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([128], "B") as vi: - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - B[vi, 0] = A[vi, 0] - if A[vi, 0] == 0.0: - with T.block([], "C"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "D") as vj: - B[vi, vj] = A[vi, vj] * 3.0 - else: - with T.block([], "E"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "F") as vj: - B[vi, vj] = A[vi, vj] * 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + B[vi, 0] = A[vi, 0] + if A[vi, 0] == 0.0: + with T.block("C"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("D"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 3.0 + else: + with T.block("E"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("F"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -78,10 +90,14 @@ def write_after_read(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -90,9 +106,11 @@ def loop_carried_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128,)) C = T.match_buffer(c, (128,)) for i in range(0, 128): - with T.block([128], "B") as vi: + with T.block("B"): + vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 - with T.block([128], "C") as vi: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = T.if_then_else(vi >= 1, B[vi - 1] + 1.0, 0.0, dtype="float32") @@ -101,14 +119,17 @@ def concatenate_multi_producer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 64): - with T.block([64], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(64, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 - with T.block([128], "B") as vi: - B[vi] = A[vi] * 2.0 + for i in range(0, 128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] * 2.0 @T.prim_func @@ -116,14 +137,17 @@ def concatenate_multi_producer_uncovered(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 63): - with T.block([63], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(63, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 - with T.block([128], "B") as vi: - B[vi] = A[vi] * 2.0 + for i in range(0, 128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] * 2.0 @T.prim_func @@ -132,9 +156,11 @@ def lca_at_loop(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128,)) C = T.match_buffer(c, (128,)) for i in range(0, 128): - with T.block([128], "B") as vi: + with T.block("B"): + vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 - with T.block([128], "C") as vi: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = B[vi] + 1.0 @@ -143,18 +169,20 @@ def multi_producer_consumer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 64): - with T.block([64], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(64, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 for i in range(0, 64): - with T.block([64], "B_0") as vi: + with T.block("B_0"): + vi = T.axis.S(64, i) B[vi] = A[vi] + 2.0 for i in range(0, 64): - with T.block([64], "B_1") as vi: - T.bind(vi, i + 64) + with T.block("B_1"): + vi = T.axis.S(64, i + 64) B[vi] = A[vi] + 3.0 @@ -164,12 +192,14 @@ def elementwise_affine_producer(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j, k, l in T.grid(16, 2, 32, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 8 + j * 4 + k // 8) - T.bind(vj, k % 8 * 16 + l) + with T.block("B"): + vi = T.axis.S(128, i * 8 + j * 4 + k // 8) + vj = T.axis.S(128, k % 8 * 16 + l) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -177,13 +207,19 @@ def elementwise_subblock(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([32, 32], "B") as [vi, vj]: - T.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) - T.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) - with T.block([4, 4], "B_sub") as [vi_i, vj_i]: - B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(32, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) + T.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) + for ii, jj in T.grid(4, 4): + with T.block("B_sub"): + vi_i, vj_i = T.axis.remap("SS", [ii, jj]) + B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -191,13 +227,19 @@ def elementwise_subblock_uncovered(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([32, 32], "B") as [vi, vj]: - T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) - T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) - with T.block([2, 2], "B_sub") as [vi_i, vj_i]: - B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(32, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) + T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) + for ii, jj in T.grid(2, 2): + with T.block("B_sub"): + vi_i, vj_i = T.axis.remap("SS", [ii, jj]) + B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -207,10 +249,12 @@ def bound_to_thread(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], scope="shared") for i in T.thread_binding(0, 128, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vj, vi] = B[vj, vi] + 1.0 @@ -222,14 +266,14 @@ def equal_ranked_threads(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 16, thread="threadIdx.x"): for i_i in T.thread_binding(0, 8, thread="threadIdx.y"): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_o * 8 + i_i) - T.bind(vj, j) + with T.block("B"): + vi = T.axis.S(128, i_o * 8 + i_i) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 8 + i_i) - T.bind(vj, j) + with T.block("C"): + vi = T.axis.S(128, i_o * 8 + i_i) + vj = T.axis.S(128, j) C[vj, vi] = B[vj, vi] + 1.0 @@ -241,10 +285,12 @@ def warp_memory(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: + with T.block("B"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 for j in T.serial(0, 128): - with T.block([4, 32, 128], "C") as [warp_id, lane_id, vj]: + with T.block("C"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 @@ -256,11 +302,15 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: + with T.block("B"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 for i_o_prime in T.thread_binding(0, 4, thread="threadIdx.y"): for j in T.serial(0, 128): - with T.block([4, 32, 4, 128], "C") as [_warp_id, lane_id, warp_id, vj]: + with T.block("C"): + _warp_id, warp_id, lane_id, vj = T.axis.remap( + "SSSS", [i_o, i_i, i_o_prime, j] + ) C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index 7d0e91f70e60..3b699fd8f1b2 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -29,22 +29,20 @@ def element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, ax1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) @@ -55,23 +53,21 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, ax1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) T.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) @@ -82,23 +78,21 @@ def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.block_attr({"buffer_dim_align": [0]}) - T.bind(vi, i0) - T.bind(vj, ax1) + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index 36e05c6b5170..f1c97c57b2ff 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -32,18 +32,24 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func def elementwise_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 185d229b44e1..440d0ab67a50 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -34,10 +34,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py index 86dc5dffed9f..72666a89ebcb 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -27,10 +27,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None: B = T.match_buffer(b, [m, n]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, n)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, n): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -39,10 +41,12 @@ def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -52,10 +56,12 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [m, 128]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -66,10 +72,12 @@ def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [m, x * 8]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, x * 8): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -81,11 +89,15 @@ def element_wise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((m, n), "float32") - with T.block([m, n], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(m, n): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([m, n], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(m, n): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -94,11 +106,15 @@ def element_wise_128_64(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 64), "float32") B = T.alloc_buffer((128, 64), "float32") - with T.block([128, 64], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 64): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 64], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 64): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -108,11 +124,15 @@ def element_wise_128_n(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, n), "float32") B = T.alloc_buffer((128, n), "float32") - with T.block([128, n], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, n): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, n], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, n): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -120,8 +140,10 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int32, q: T. A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q) B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q) - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -129,8 +151,10 @@ def mem_copy_16_16_8_4(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4) B = T.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4) - with T.block([16, 16], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -138,8 +162,10 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int3 A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n) B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n) - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -147,8 +173,10 @@ def param_in_arith_exprs(a: T.handle, b: T.handle) -> None: n = T.var("int32") A = T.match_buffer(a, [n // 8, 8], "int32") B = T.match_buffer(b, [n], "int32") - with T.block([n - 1], "") as [vi]: - B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 + for i in range(n - 1): + with T.block(): + vi = T.axis.S(n - 1, i) + B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 @T.prim_func @@ -156,8 +184,10 @@ def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None: n = T.var("int32") A = T.match_buffer(a, [2, 8], "int32") B = T.match_buffer(b, [16], "int32") - with T.block([15], "") as [vi]: - B[vi] = A[vi // 8, vi % 8] + 714 + for i in range(15): + with T.block(): + vi = T.axis.S(15, i) + B[vi] = A[vi // 8, vi % 8] + 714 def test_specialize_nothing(): diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 0cfc724e41de..7d3115428f5a 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -32,17 +32,17 @@ def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 @@ -53,7 +53,7 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((1, 16), "float32") @@ -74,7 +74,7 @@ def unschedulable_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") @@ -89,11 +89,11 @@ def param_buffer_access_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (20, 20), "float32") B = T.match_buffer(c, (20, 20), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(B[i, 0:16]) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 @@ -106,17 +106,17 @@ def shared_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((16, 16), "float32", scope="shared") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i0 * 8 + i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 @@ -129,17 +129,17 @@ def compacted_shared_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((8, 16), "float32", scope="shared") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i1 * 4 + i2, j]) B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0 @@ -152,17 +152,17 @@ def warp_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((16, 16), "float32", scope="warp") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i0 * 8 + i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 @@ -175,17 +175,17 @@ def compacted_warp_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((4, 16), "float32", scope="warp") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i2, j]) B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0 @@ -196,17 +196,17 @@ def symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (n * 8,), "float32") C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) B = T.alloc_buffer((n * 8,), "float32") for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 8 + j]) T.writes(B[i * 8 + j]) B[i * 8 + j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[i * 8 + j]) T.writes(C[i * 8 + j]) C[i * 8 + j] = B[i * 8 + j] * 2.0 @@ -217,17 +217,17 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (n * 8,), "float32") C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) B = T.alloc_buffer((8,), "float32") for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 8 + j]) T.writes(B[j]) B[j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[j]) T.writes(C[i * 8 + j]) C[i * 8 + j] = B[j] * 2.0 @@ -238,12 +238,12 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (8, 8), "float32") C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with T.block([]): + with T.block(): T.reads(A[0, 8]) T.writes(C[0, 8]) B = T.alloc_buffer((8, 8), "float32") for j in range(0, 4): - with T.block([]) as []: + with T.block() as []: D = T.alloc_buffer((8, 8), "float32") T.reads(A[i, j]) T.writes(B[i, j]) @@ -252,12 +252,12 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(2, 4): T.store(B.data, j, A[i, j] + D[k, j]) for j in range(3, 5): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] for j in range(6, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] @@ -268,12 +268,12 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (8, 8), "float32") C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with T.block([]): + with T.block(): T.reads(A[0, 8]) T.writes(C[0, 8]) B = T.alloc_buffer((1, 8), "float32") for j in range(0, 4): - with T.block([]) as []: + with T.block() as []: D = T.alloc_buffer((6, 1), "float32") T.reads(A[i, j]) T.writes(B[0, j]) @@ -282,12 +282,12 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(2, 4): T.store(B.data, j, A[i, j] + D[k - 2, 0]) for j in range(3, 5): - with T.block([]) as []: + with T.block() as []: T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] for j in range(6, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] @@ -298,19 +298,19 @@ def match_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) C0 = T.match_buffer(C[i, 0:16], (16)) B = T.alloc_buffer((16, 16)) - with T.block([]): + with T.block(): B0 = T.match_buffer(B[i, 0:16], (16)) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: C1 = T.match_buffer(C0[j], ()) B2 = T.match_buffer(B[i, j], ()) C1[()] = B2[()] * 2.0 @@ -321,19 +321,19 @@ def compacted_match_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) C0 = T.match_buffer(C[i, 0:16], (16)) B = T.alloc_buffer((1, 16)) - with T.block([]): + with T.block(): B0 = T.match_buffer(B[0, 0:16], (16)) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: C1 = T.match_buffer(C0[j], ()) B2 = T.match_buffer(B[0, j], ()) C1[()] = B2[()] * 2.0 @@ -344,18 +344,18 @@ def storage_align_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 @@ -366,7 +366,7 @@ def compacted_storage_align_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32") diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index 287a30916520..ee323a64c50f 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -32,19 +32,19 @@ def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) B[vi, vj] = A[vi, vj] + 1.0 for j in range(0, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) C[vi, vj] = B[vi, vj] * 2.0 @@ -53,7 +53,7 @@ def substituted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer([16, 16], "float32") diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py new file mode 100644 index 000000000000..a91fa2591e00 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm + +from tvm.script import tir as T +from tvm.tir import stmt_functor + +# fmt: off +@T.prim_func +def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global") + for i0_i1_fused_3 in T.parallel(0, 28): + for i2_3, i3_3 in T.grid(28, 192): + T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): + for ax3_2 in T.serial(0, 16): + Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") + T.store(Conv2dOutput_3, 0, 0, True) + for rc_3 in T.serial(0, 192): + T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) + T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) +# fmt: on + + +def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(): + primfunc = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2 + mod = tvm.IRModule.from_expr(primfunc) + mod = tvm.tir.transform.ConvertForLoopsToSerial()(mod) + + def verify_serial_loops(stmt): + if isinstance(stmt, tvm.tir.For): + assert stmt.kind == tvm.tir.ForKind.SERIAL + + for _, primfunc in mod.functions.items(): + stmt_functor.post_order_visit(primfunc.body, verify_serial_loops) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 21c896c7bb7e..eed82ebb9118 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -32,7 +32,7 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer([1, 16], "float32", scope="global") @@ -67,7 +67,7 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 4, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="threadIdx.x"): for i2 in T.thread_binding(0, 2, thread="vthread"): - with T.block([]): + with T.block(): T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) B = T.alloc_buffer([1, 16], "float32", scope="local") @@ -108,17 +108,17 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i, m]) T.writes(C[i, m]) B = T.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[j]) B[j] = A[i, j] + 1.0 for j in range(0, m): - with T.block([]) as []: + with T.block() as []: T.reads(B[j]) T.writes(C[i, j]) C[i, j] = B[j] * 2.0 @@ -143,7 +143,7 @@ def compacted_predicate_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for i, j in T.grid(5, 7): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 7 + j]) T.writes(C[i * 7 + j]) T.where(i * 7 + j < 32) @@ -166,7 +166,7 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for x, y, z in T.grid(4, 1, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[x * 8 + y * 8 + z]) T.writes(C[x * 8 + y * 8 + z]) C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0 @@ -187,7 +187,7 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - with T.block([]) as []: + with T.block() as []: T.reads(A[i]) T.writes(D[i]) B = T.alloc_buffer((32,), scope="global") @@ -215,7 +215,7 @@ def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i0 in range(0, 4): - with T.block([]): + with T.block(): T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16]) T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16]) B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], scope="global") diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 86bf87d5fa85..aa0448c3c682 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm import te +from tvm.driver.build_module import schedule_to_module def test_copy2d(): @@ -53,11 +54,7 @@ def test_copy_pad(): ) s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): @@ -77,11 +74,7 @@ def test_single_point_test(): B = te.compute((1,), lambda i: A[i], name="B") s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): @@ -105,11 +98,8 @@ def test_copy_pad_split(): xo, xi = s[B].split(B.op.axis[0], factor=4) s[Apad].compute_at(s[B], xo) s[Apad].pragma(s[Apad].op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index c1c4fb3d2e8f..a4fd9404eee4 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir, te +from tvm import te from tvm.script import tir as T # pylint: disable=no-self-argument @@ -28,10 +28,13 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - with T.init(): - B[i] = T.float32(0) - B[i] += A[i, j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + with T.init(): + B[i] = T.float32(0) + B[i] += A[i, j, k] @tvm.script.ir_module @@ -41,10 +44,13 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - if (j == 0) and (k == 32): - B[i] = T.float32(0) - B[i] += A[i, j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + if (j == 0) and (k == 32): + B[i] = T.float32(0) + B[i] += A[i, j, k] @tvm.script.ir_module @@ -54,12 +60,15 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - BB = T.match_buffer(B[i], ()) - AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) - with T.init(): - BB[()] = T.float32(0) - BB[()] += AA[j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + BB = T.match_buffer(B[i], ()) + AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with T.init(): + BB[()] = T.float32(0) + BB[()] += AA[j, k] @tvm.script.ir_module @@ -69,17 +78,21 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - BB = T.match_buffer(B[i], ()) - AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) - if (j == 0) and (k == 32): - BB[()] = T.float32(0) - BB[()] += AA[j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + BB = T.match_buffer(B[i], ()) + AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) + if (j == 0) and (k == 32): + BB[()] = T.float32(0) + BB[()] += AA[j, k] def test_lower_reduction(): origin_mod = WithInit mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + print(mod.script()) tvm.ir.assert_structural_equal(mod, WithBranch, True) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 15f994069abd..1ab6bdaad90a 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy + import tvm from tvm import te -import numpy +from tvm.driver.build_module import schedule_to_module def test_makeapi(): @@ -27,10 +29,7 @@ def test_makeapi(): C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") s = te.create_schedule(C.op) - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [n, A, B, C]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Apply( lambda f: f.with_attr( diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 9c511f1de6b9..cc78b84f9b4e 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -14,20 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te import numpy as np + +import tvm import tvm.testing +from tvm import te +from tvm.driver.build_module import schedule_to_module from tvm.topi.math import cast def run_passes(sch, args): - bounds = tvm.te.schedule.InferBound(sch) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) return tvm.transform.Sequential( [ tvm.tir.transform.StorageFlatten(64), diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index cb8968cfc880..9b95266d3287 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te -from tvm import relay +from tvm import te, relay +from tvm.driver.build_module import schedule_to_module from tvm.tir import const @@ -39,11 +39,8 @@ def lower_sch(sch, args, target_bits): else: raise ValueError("args must be Tensor, Buffer or Var") sch = sch.normalize() - bounds = te.schedule.InferBound(sch) - stmt = te.schedule.ScheduleOps(sch, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) mod = tvm.tir.transform.StorageFlatten(64)(mod) return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body @@ -66,7 +63,8 @@ def check(m, n, target_bits, target_dtype): # const shape # i32 -> i32 check(2, 2, 32, "int32") - check(2 ** 16, 2 ** 16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow + # i32 + i32 is not promoted to i64 even if overflow + check(2 ** 16, 2 ** 16, 32, "int32") # i64 -> i32 check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32") check(const(2 ** 16, dtype="int64"), const(2 ** 16, dtype="int64"), 32, "int64") @@ -188,7 +186,7 @@ def check(m, n, target_bits, target_dtype): def test_relay_basic(): - engine = relay.backend.compile_engine.get() + engine = relay.backend.te_compiler.get() def check(shapex, shapey, target_bits, target_dtype): x = relay.var("x", shape=shapex) @@ -230,7 +228,7 @@ def check(shapex, shapey, target_bits, target_dtype): def test_relay_take(): - engine = relay.backend.compile_engine.get() + engine = relay.backend.te_compiler.get() def check(shape, index, target_bits, target_dtype): x = relay.var("x", shape=shape) diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index e55555305a09..c22f5f82ee10 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir, te +from tvm import te from tvm.script import tir as T @@ -31,12 +31,14 @@ def element_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) B = T.alloc_buffer((16, 16)) - for i_0 in range(0, 16): - for j_0 in range(0, 16): - with T.block([16, 16]) as [i, j]: + for i0 in range(0, 16): + for j0 in range(0, 16): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) B[i, j] = A[i, j] + 1.0 - for j_0 in range(0, 16): - with T.block([16, 16]) as [i, j]: + for j0 in range(0, 16): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) C[i, j] = B[i, j] * 2.0 @@ -46,95 +48,112 @@ def transformed_element_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16, 16]) for i_0 in range(0, 16): - with T.block([]): + with T.block(): T.reads([A[i_0, 0:16]]) T.writes([C[i_0, 0:16]]) B = T.alloc_buffer([16, 16]) for j_0 in T.serial(0, 16): - with T.block([16, 16], "") as [i, j]: - T.bind(i, i_0) - T.bind(j, j_0) + with T.block(): + i, j = T.axis.remap("SS", [i_0, j_0]) B[i, j] = A[i, j] + 1.0 for j_0 in T.serial(0, 16): - with T.block([16, 16], "") as [i, j]: - T.bind(i, i_0) - T.bind(j, j_0) + with T.block(): + i, j = T.axis.remap("SS", [i_0, j_0]) C[i, j] = B[i, j] * 2.0 @T.prim_func def original_func() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128]) as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)]) as [i, j, k]: - B = T.alloc_buffer((128, 128), "float32") - C = T.alloc_buffer((128, 128), "float32") - D = T.alloc_buffer((128, 128), "float32") - if k == 0: + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + B = T.alloc_buffer((128, 128), "float32") + C = T.alloc_buffer((128, 128), "float32") + D = T.alloc_buffer((128, 128), "float32") + if k == 0: + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += ( + D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + ) @T.prim_func def transformed_func() -> None: A = T.alloc_buffer([128, 128]) - with T.block([128, 128], "") as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)], "") as [i, j, k]: - B = T.alloc_buffer([128, 128]) - if k == 0: + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + B = T.alloc_buffer([128, 128]) + if k == 0: + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - with T.block([], ""): - T.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) - T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) - C = T.alloc_buffer([128, 128]) - for kk in T.serial(0, 4): - B[((i * 4) + ii), ((j * 4) + jj)] = ( - B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] - ) - for kk in T.serial(0, 4): - with T.block([], ""): - T.reads( - [ - B[((i * 4) + ii), ((j * 4) + jj)], - C[((i * 4) + ii), ((k * 4) + kk)], - ] - ) - T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) - D = T.alloc_buffer([128, 128]) - B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + ( - D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)] + with T.block(""): + T.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) + T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + C = T.alloc_buffer([128, 128]) + for kk in T.serial(0, 4): + B[((i * 4) + ii), ((j * 4) + jj)] = ( + B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] ) + for kk in T.serial(0, 4): + with T.block(""): + T.reads( + [ + B[((i * 4) + ii), ((j * 4) + jj)], + C[((i * 4) + ii), ((k * 4) + kk)], + ] + ) + T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + D = T.alloc_buffer([128, 128]) + B[((i * 4) + ii), ((j * 4) + jj)] = B[ + ((i * 4) + ii), ((j * 4) + jj) + ] + ( + D[((j * 4) + jj), ((k * 4) + kk)] + * C[((i * 4) + ii), ((k * 4) + kk)] + ) @T.prim_func def match_buffer_func() -> None: C = T.alloc_buffer((128, 128)) - with T.block([128]) as [vi]: - C0 = T.match_buffer(C[vi, 0:128], (128)) - with T.block([128]) as [jj]: - C1 = T.match_buffer(C0[jj], ()) - C1[()] = 0 + for i in range(128): + with T.block(): + vi = T.axis.S(128, i) + C0 = T.match_buffer(C[vi, 0:128], (128)) + for j in range(128): + with T.block(): + jj = T.axis.S(128, j) + C1 = T.match_buffer(C0[jj], ()) + C1[()] = 0 @T.prim_func def transformed_match_buffer_func() -> None: for i in range(0, 128): - with T.block([128]) as [vi]: - T.bind(vi, i) + with T.block(): + vi = T.axis.S(128, i) C = T.alloc_buffer((128, 128)) C0 = T.match_buffer(C[vi, 0:128], (128)) - with T.block([128]) as [jj]: - C1 = T.match_buffer(C0[jj], ()) - C1[()] = 0 + for j in range(128): + with T.block(): + jj = T.axis.S(128, j) + C1 = T.match_buffer(C0[jj], ()) + C1[()] = 0 @T.prim_func @@ -143,9 +162,10 @@ def opaque_access(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [1024]) A_cache = T.alloc_buffer([1024]) for i in T.serial(0, 8): - with T.block([8]) as [vi]: - with T.block([8]) as [v]: - T.bind(v, vi) + with T.block(): + vi = T.axis.S(8, i) + with T.block(): + v = T.axis.S(8, vi) T.reads([A[(v * 128) : ((v * 128) + 128)]]) T.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) T.evaluate( @@ -161,8 +181,8 @@ def opaque_access(a: T.handle, b: T.handle) -> None: ) ) for j in T.serial(0, 128): - with T.block([1024]) as [v]: - T.bind(v, ((vi * 128) + j)) + with T.block(): + v = T.axis.S(1024, vi * 128 + j) T.reads([A_cache[v]]) T.writes([B[v]]) B[v] = A_cache[v] @@ -173,12 +193,13 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024]) B = T.match_buffer(b, [1024]) for i in T.serial(0, 8): - with T.block([8]) as [vi]: + with T.block(): + vi = T.axis.S(8, i) T.reads(A[vi * 128 : vi * 128 + 128]) T.writes(B[vi * 128 : vi * 128 + 128]) A_cache = T.alloc_buffer([1024]) - with T.block([8]) as [v]: - T.bind(v, vi) + with T.block(): + v = T.axis.S(8, vi) T.reads([A[v * 128 : v * 128 + 128]]) T.writes([A_cache[v * 128 : v * 128 + 128]]) T.evaluate( @@ -187,8 +208,8 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: ) ) for j in T.serial(0, 128): - with T.block([1024]) as [v]: - T.bind(v, ((vi * 128) + j)) + with T.block(): + v = T.axis.S(1024, vi * 128 + j) T.reads([A_cache[v]]) T.writes([B[v]]) B[v] = A_cache[v] diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 37223493a8b5..a51e926155d3 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.driver.build_module import schedule_to_module from tvm.script import tir as T from tvm.relay import GlobalVar @@ -30,14 +31,10 @@ def test_flatten2(): s = te.create_schedule(A2.op) xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A") A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2") - func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, A2b], stmt, {A: Ab, A2: A2b}) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [Ab, A2b], binds={A: Ab, A2: A2b}) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -70,12 +67,8 @@ def test_flatten_storage_align(): s = te.create_schedule(A2.op) s[A1].storage_align(A1.op.axis[0], 2, 1) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, A2]) mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 9e738b136b17..5a91788283d6 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.driver.build_module import schedule_to_module def test_storage_share(): @@ -28,12 +29,7 @@ def test_storage_share(): B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t) s = te.create_schedule(B.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -169,12 +165,7 @@ def test_inplace_rule(): AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA") B = te.compute((m,), lambda i: AA[i] + 1, name="B") s = te.create_schedule(B.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -206,11 +197,8 @@ def test_storage_combine(): s = te.create_schedule(B.op) for S in stages[:-1]: s[S].set_scope("global:tag") - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -238,10 +226,7 @@ def test_storage_combine_with_vectorization(): BB = s.cache_read(B, "global:tag", readers=[C]) CC = s.cache_write(C, "global:tag") s[CC].vectorize(s[CC].op.axis[0]) - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B, C]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.VectorizeLoop()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) @@ -285,11 +270,7 @@ def test_storage_share_gpu(): s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx) s[A[2 * t + 1]].set_scope("shared") - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A[0], A[-1]]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) @@ -418,12 +399,7 @@ def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): A0L = s.cache_read(A0, scope_tb, [A2]) A1L = s.cache_read(A1, scope_tb, [A2]) A2L = s.cache_read(A2, scope_tb, [B]) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B, C, D]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -511,12 +487,7 @@ def test_inplace_rule3(): s[B10].compute_inline() s = s.normalize() - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([B0, B1, B2, B3, B4, B5, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index cbcdc4356250..c00fc02c4331 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -46,8 +46,11 @@ def test_device_parametrization(self, dev): self.devices_used.append(dev) def test_all_targets_used(self): - assert self.targets_used == self.enabled_targets - assert self.devices_used == self.enabled_devices + assert sorted(self.targets_used) == sorted(self.enabled_targets) + + def test_all_devices_used(self): + sort_key = lambda dev: (dev.device_type, dev.device_id) + assert sorted(self.devices_used, key=sort_key) == sorted(self.enabled_devices, key=sort_key) targets_with_explicit_list = [] @@ -70,9 +73,9 @@ def test_exclude_target(self, target): self.targets_with_exclusion.append(target) def test_all_nonexcluded_targets_ran(self): - assert self.targets_with_exclusion == [ - target for target in self.enabled_targets if not target.startswith("llvm") - ] + assert sorted(self.targets_with_exclusion) == sorted( + [target for target in self.enabled_targets if not target.startswith("llvm")] + ) run_targets_with_known_failure = [] @@ -85,7 +88,7 @@ def test_known_failing_target(self, target): assert "llvm" not in target def test_all_targets_ran(self): - assert self.run_targets_with_known_failure == self.enabled_targets + assert sorted(self.run_targets_with_known_failure) == sorted(self.enabled_targets) @tvm.testing.known_failing_targets("llvm") @tvm.testing.parametrize_targets("llvm") diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 7c521db21bb8..105b4a2d6a3f 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -26,10 +26,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -39,12 +41,14 @@ def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(32, 32): - with T.block([32, 32], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) for ii, jj in T.grid(4, 4): C[vi * 4 + ii, vj * 4 + jj] = T.float32(0) for k in range(0, 32): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) for ii, jj, kk in T.grid(4, 4, 4): C[vi * 4 + ii, vj * 4 + jj] = ( C[vi * 4 + ii, vj * 4 + jj] @@ -58,12 +62,15 @@ def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([128, 128]) as [vi, vj]: - B[vi, vj] = A[vi, vj] + T.float32(1) - - with T.block([128, 128]) as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + with T.block() as []: + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: @@ -71,12 +78,13 @@ def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([]) as []: + with T.block() as []: + with T.block() as []: B[0, 0] = A[0, 0] + T.float32(1) - - with T.block([128, 128]) as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) @T.prim_func @@ -85,14 +93,18 @@ def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([128, 128]) as [vi, vj]: - T.reads(A[vi, vj]) - B[vi, vj] = A[vi, vj] + T.float32(1) + with T.block() as []: + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + B[vi, vj] = A[vi, vj] + T.float32(1) - with T.block([128, 128]) as [vi, vj]: - T.writes(C[vi, vj]) - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) def test_complete_matmul(): @@ -181,22 +193,23 @@ def func_with_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, (1,), "int32") out_buf = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - out_buf[vi, vj] = data_buf[vi, index_buf[0]] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + out_buf[vi, vj] = data_buf[vi, index_buf[0]] @T.prim_func def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(16, 16): - with T.block([16, 16], "") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block(): + vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[vi, 0:16], index_buf[0]]) T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[vi, index_buf[0]] @@ -208,22 +221,23 @@ def func_with_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> index_buf = T.match_buffer(index, (1,), "int32") out_buf = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] @T.prim_func def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(16, 16): - with T.block([16, 16], "") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block(): + vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[0:16, 0:16], index_buf[0]]) T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] @@ -240,11 +254,11 @@ def test_complete_buffer_indices(): def match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) - with T.block([]): + with T.block(): for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) A1[()] = 1.0 @@ -253,15 +267,15 @@ def match_buffer_func(a: T.handle) -> None: def expected_match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, 0:16]) A0 = T.match_buffer(A[i, 0:16], (16)) - with T.block([]): + with T.block(): T.reads([]) T.writes(A0[0:16]) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads([]) T.writes(A0[j]) A1 = T.match_buffer(A0[j], ()) @@ -272,6 +286,34 @@ def test_complete_match_buffer(): tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) +@T.prim_func +def alloc_buffer_func(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [2, 2], dtype="float32") + B = T.match_buffer(b, [2, 2], dtype="float32") + C = T.alloc_buffer([2, 2], dtype="float32") + A[(0, 0)] = T.float32(2) + C[(0, 0)] = A[(0, 0)] + B[(0, 0)] + B[(0, 0)] = C[(0, 0)] + + +@T.prim_func +def expect_alloc_buffer_func(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) + B = T.match_buffer(b, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) + with T.block("root"): + T.reads([]) + T.writes([]) + C = T.alloc_buffer([2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) + A[(0, 0)] = T.float32(2) + C[(0, 0)] = A[(0, 0)] + B[(0, 0)] + B[(0, 0)] = C[(0, 0)] + + +def test_complete_alloc_buffer(): + rt_func = tvm.script.from_source(alloc_buffer_func.script(show_meta=True)) + tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func) + + if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() @@ -279,3 +321,4 @@ def test_complete_match_buffer(): test_complete_part_region() test_complete_buffer_indices() test_complete_match_buffer() + test_complete_alloc_buffer() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 99a22636b927..3098c86a7c2e 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -18,6 +18,7 @@ import pytest import sys import tvm +from tvm import tir from tvm.script import tir as T from tvm.ir.diagnostics import override_renderer import inspect @@ -155,33 +156,83 @@ def test_allocate_with_buffers(): check_error(allocate_with_buffers, 2) -def inconsistent_binding() -> None: - with T.block([128, 128]) as [vi]: # error +def inconsistent_binding_value() -> None: + for i, j in T.grid(16, 16): + vi, vj = T.axis.remap("SS", [i]) # error + T.evaluate(1.0) + + +def inconsistent_binding_type() -> None: + for i, j in T.grid(16, 16): + vi, vj = T.axis.remap("S", [i, j]) # error T.evaluate(1.0) def test_inconsistent_binding(): - check_error(inconsistent_binding, 2) + check_error(inconsistent_binding_value, 3) + check_error(inconsistent_binding_type, 3) + + +def error_remap_type() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("TT", [i, j]) # error + T.evaluate(1.0) + + +def error_remap_value() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i + j, j]) # error + T.evaluate(1.0) + + +def test_error_remap_args(): + check_error(error_remap_type, 4) + check_error(error_remap_value, 4) def invalid_block_axes(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - with T.block([A]) as [vi]: # error - T.evaluate(1.0) + for i, j in T.grid(16, 16): + with T.block(): + vi = T.axis.S(i, A) # error + T.evaluate(1.0) def test_invalid_block_axes(): - check_error(invalid_block_axes, 3) + check_error(invalid_block_axes, 5) -def miss_block_bind() -> None: - with T.block([16, 16]) as [vi, vj]: # error - T.bind(vi, 1) - T.evaluate(1.0) +def duplicate_block_axes() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi = T.axis.S(16, i) + vi = T.axis.S(16, j) # error + T.evaluate(1.0) + + +def duplicate_block_axes_remap() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vi = T.axis.remap("SS", [i, j]) # error + T.evaluate(1.0) + + +def test_duplicate_block_axes(): + check_error(duplicate_block_axes, 5) + check_error(duplicate_block_axes_remap, 4) + + +def miss_block_bind_value() -> None: + for i, j in T.grid(128, 128): + with T.block(): + vi = T.axis.S(i) # error + T.evaluate(1.0) def test_miss_block_bind(): - check_error(miss_block_bind, 2) + check_error(miss_block_bind_value, 4) def invalid_loop_var() -> None: @@ -203,74 +254,99 @@ def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: - with T.block([16, 16]) as [vi, vj]: - A = T.match_buffer(vi) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A = T.match_buffer(vi) # error + T.evaluate(1.0) def test_invalid_match_buffer_region(): - check_error(invalid_match_buffer_region, 3) + check_error(invalid_match_buffer_region, 5) def duplicate_buffer() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A = T.alloc_buffer((128, 128), "float32") # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A = T.alloc_buffer((128, 128), "float32") # error + T.evaluate(1.0) def test_duplicate_buffer(): - check_error(duplicate_buffer, 4) + check_error(duplicate_buffer, 6) def duplicate_reads() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - T.reads(A[0:8, 0:8]) - T.reads(A[0:16, 0:16]) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[0:8, 0:8]) + T.reads(A[0:16, 0:16]) # error + T.evaluate(1.0) def duplicate_writes() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - T.writes(A[0:8, 0:8]) - T.writes(A[0:16, 0:16]) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.writes(A[0:8, 0:8]) + T.writes(A[0:16, 0:16]) # error + T.evaluate(1.0) def duplicate_predicate() -> None: - with T.block([16, 16]) as [vi, vj]: - T.where(1) - T.where(0) # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.where(1) + T.where(0) # error def duplicate_annotations() -> None: - with T.block([16, 16]) as [vi, vj]: - T.block_attr({}) - T.block_attr({}) # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({}) + T.block_attr({}) # error def duplicate_init() -> None: - with T.block([16, 16]) as [vi, vj]: - with T.init(): - T.evaluate(1.0) - with T.init(): # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + with T.init(): + T.evaluate(1.0) + with T.init(): # error + T.evaluate(1.0) + + +def duplicate_axes() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + vi = T.axis.S(i, 16) # error T.evaluate(1.0) def test_duplicate_block_signature(): - check_error(duplicate_reads, 5) - check_error(duplicate_writes, 5) - check_error(duplicate_predicate, 4) - check_error(duplicate_annotations, 4) - check_error(duplicate_init, 5) + check_error(duplicate_reads, 7) + check_error(duplicate_writes, 7) + check_error(duplicate_predicate, 6) + check_error(duplicate_annotations, 6) + check_error(duplicate_init, 7) + check_error(duplicate_axes, 5) def opaque_access_during_complete(a: T.handle) -> None: # error A = T.match_buffer(a, (16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - T.evaluate(T.load("float32", A.data, vi * 16 + vj)) + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.evaluate(T.load("float32", A.data, vi * 16 + vj)) def test_opaque_access_during_complete(): @@ -279,55 +355,65 @@ def test_opaque_access_during_complete(): def convert_slice_to_bufferload() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi : vi + 2, vj] + 1 # error + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi : vi + 2, vj] + 1 # error def test_convert_slice_to_bufferload(): - check_error(convert_slice_to_bufferload, 4) + check_error(convert_slice_to_bufferload, 6) def error_index_type() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi, 0.0] + 1 # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi, 0.0] + 1 # error def error_bufferslice_index_type() -> None: A = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - C[vi, vj] = B[vi, A[0]] # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, A[0]] # error def test_error_index_type(): - check_error(error_index_type, 4) - check_error(error_bufferslice_index_type, 6) + check_error(error_index_type, 6) + check_error(error_bufferslice_index_type, 8) def error_index_with_stop() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi, 1:10] + 1 # error + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi, 1:10] + 1 # error def error_bufferslice_index_with_stop() -> None: A = T.alloc_buffer((1,), "int32") B = T.alloc_buffer((16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - C[vi, vj] = B[vi, A[0:1]] # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, A[0:1]] # error def test_error_index_with_stop_slice(): - check_error(error_index_with_stop, 4) - check_error(error_bufferslice_index_with_stop, 6) + check_error(error_index_with_stop, 6) + check_error(error_bufferslice_index_with_stop, 8) def mismatch_args() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: + with T.block(): T.reads(A[0, 0], A[1, 1]) # error T.evaluate(1.0) @@ -338,8 +424,7 @@ def test_mismatch_args(): def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error - with T.block([16, 16]) as [vi, vj]: - T.evaluate(1.0) + T.evaluate(1.0) def scope_handler_except() -> None: @@ -368,7 +453,7 @@ def test_tvm_exception_catch(): def buffer_shape_mismatch(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j * 4 : j * 4 + 4]]) sub_A = T.match_buffer( @@ -383,7 +468,7 @@ def test_match_buffer_shape_mismatch(): def high_dim_store() -> None: - with T.block([], "root"): + with T.block("root"): B = T.allocate([256], "float32", "global") for i, j in T.grid(16, 16): B[i, j] = 1.0 # error: Store is only allowed with one index @@ -393,6 +478,15 @@ def test_high_dim_store(): check_error(high_dim_store, 5) +def block_has_option_vars() -> None: + with T.block("root") as x: # error: block does not support option_vars + T.evaluate(0.0) + + +def test_block_has_option_vars(): + check_error(block_has_option_vars, 2) + + def check_error(func, rel_lineno): # Override the default renderer to accumulate errors errors = [] @@ -416,5 +510,79 @@ def render(e): ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" +# TODO(Siyuan): block iter errors. + + +@T.prim_func +def elementwise_not_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 8): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@T.prim_func +def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.alloc_buffer((128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in T.serial(0, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +def test_reorder_fail_block(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(l, i) + expected_sub_error_message = ( + " # tir.Block#0\n" + ' with tir.block("B"):\n' + " ^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop_inner(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(k, i) + expected_sub_error_message = ( + " for i in tir.serial(0, 128):\n" + " # tir.For#0\n" + " for j in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_fuse_fail_nested_loop_outer(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.fuse(k, i) + expected_sub_error_message = ( + " # tir.For#1\n" + " for i in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + " for j in tir.serial(0, 128):\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py index c55fd7b69282..82f0fa5c86bc 100644 --- a/tests/python/unittest/test_tvmscript_ops.py +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -37,22 +37,25 @@ def get_valid_counts( out_buf = T.match_buffer(out, (1, 2500, 6), "float32") out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32") - with T.block([1], "init") as [vi]: + with T.block("init"): + vi = T.axis.S(1, 0) valid_count_buf[vi] = T.int32(0) - with T.block([2500], "update") as [vj]: - T.reads([data_buf[vi, vj, 6]]) - T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) - if (data_buf[vi, vj, score_index] > score_threshold) and ( - (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0)) - ): - for k in T.serial(0, 6): - out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k] - out_indices_buf[vi, valid_count_buf[vi]] = vj - valid_count_buf[vi] = valid_count_buf[vi] + 1 - if vj >= valid_count_buf[vi]: - for k in T.serial(0, 6): - out_buf[vi, vj, k] = T.float32(-1) - out_indices_buf[vi, vj] = T.int32(-1) + for j in range(2500): + with T.block("update"): + vj = T.axis.S(2500, j) + T.reads([data_buf[vi, vj, 6]]) + T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) + if (data_buf[vi, vj, score_index] > score_threshold) and ( + (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0)) + ): + for k in T.serial(0, 6): + out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k] + out_indices_buf[vi, valid_count_buf[vi]] = vj + valid_count_buf[vi] = valid_count_buf[vi] + 1 + if vj >= valid_count_buf[vi]: + for k in T.serial(0, 6): + out_buf[vi, vj, k] = T.float32(-1) + out_indices_buf[vi, vj] = T.int32(-1) def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, score_index): @@ -101,5 +104,64 @@ def test_get_valid_counts_script_func(): _check_get_valid_counts_with_numpy(f, (1, 2500, 6), 0.0, 0, 1) +@T.prim_func +def alloc_zero_dim_buffer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [], dtype="float32") + B = T.match_buffer(b, [], dtype="float32") + # body + # tir.with block("root") + C = T.alloc_buffer([], dtype="float32") + A[()] = T.float32(2) + C[()] = A[()] + B[()] + B[()] = C[()] + + +@T.prim_func +def alloc_zero_dim_buffer_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.match_buffer(b, (), "float32") + with T.block("root"): + T.reads([]) + T.writes([]) + C = T.alloc_buffer((), "float32") + A[()] = T.float32(2) + C[()] = A[()] + B[()] + B[()] = C[()] + + +def _check_alloc_zero_dim_buffer(f): + dtype = "float32" + ctx = tvm.cpu() + + np_data = np.zeros(shape=()).astype(dtype) + np_out = np.zeros(shape=()).astype(dtype) + tvm_data = tvm.nd.array(np_data, ctx) + tvm_out = tvm.nd.array(np_out, ctx) + + # np func exection + np_inter = np.array(1) + np_data[()] = 2.0 + np_inter[()] = np_data[()] + np_out[()] + np_out[()] = np_inter[()] + + # tvm func execution + f(tvm_data, tvm_out) + tvm.testing.assert_allclose(tvm_out.numpy(), np_out, rtol=1e-5) + + +def test_alloc_zero_dim_buffer_round_trip(): + func = alloc_zero_dim_buffer + func_with_block = alloc_zero_dim_buffer_block + rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func_with_block = tvm.script.from_source(func_with_block.script(show_meta=True)) + rt_mod = tvm.build(rt_func, "llvm") + rt_mod_with_block = tvm.build(rt_func_with_block, "llvm") + tvm.ir.assert_structural_equal(func, func_with_block) + tvm.ir.assert_structural_equal(rt_func, rt_func_with_block) + _check_alloc_zero_dim_buffer(rt_mod) + _check_alloc_zero_dim_buffer(rt_mod_with_block) + + if __name__ == "__main__": test_get_valid_counts_script_func() + test_alloc_zero_dim_buffer_round_trip() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 8058b96b024d..93b052ee1d96 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2672,10 +2672,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -2685,11 +2687,13 @@ def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -2699,11 +2703,14 @@ def element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * T.float32(2) - - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) @T.prim_func @@ -2712,9 +2719,9 @@ def predicate(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (16, 16), "float32") for i, jo, ji in T.grid(16, 4, 5): - with T.block([16, 16], "update") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, jo * 4 + ji) + with T.block("update"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, jo * 4 + ji) T.where(jo * 4 + ji < 16) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -2807,12 +2814,16 @@ def match_buffer_region(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16), "float32") B = T.match_buffer(b, (1), "float32") - with T.block([16, 4]) as [vi, vj]: - C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) - with T.block([4]) as [vii]: - D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) - for i, j in T.grid(4, 4): - B[0] += D[i, 0, j] + for i, j in T.grid(16, 4): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + for ii in range(4): + with T.block(): + vii = T.axis.S(4, ii) + D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in T.grid(4, 4): + B[0] += D[i, 0, j] def test_match_buffer_region(): @@ -2844,8 +2855,8 @@ def block_elements(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (1, 1), "float32") - with T.block([1], "update") as [vi]: - T.bind(vi, 0) + with T.block("update"): + vi = T.axis.S(1, 0) T.where(True) T.reads(A[0:16, 0:16]) T.writes(B[0, 0]) @@ -2879,11 +2890,11 @@ def opaque_block(a: T.handle, b: T.handle) -> None: for i in range(16): for j in range(16): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j]) A[i, j] = T.float32(0) - with T.block([]): + with T.block(): T.reads([A[i, 0:16]]) T.writes([B[i, 0:16]]) for j in range(16): @@ -2927,7 +2938,7 @@ def rank0_block(a: T.handle) -> None: B = T.alloc_buffer((), "float32") T.store(B.data, 0, T.load("float32", A.data, 0)) - with T.block([], "update") as []: + with T.block("update") as []: T.reads([A[()]]) T.writes([B[()]]) for i in range(1): @@ -2969,8 +2980,10 @@ def test_minmax(): def abs(a: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") - with T.block([128, 128], "A") as [vi, vj]: - A[vi, vj] = T.abs(A[vi, vj]) + for i, j in T.grid(128, 128): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = T.abs(A[vi, vj]) def test_abs(): @@ -3011,15 +3024,13 @@ def test_simplify_bracket(): @T.prim_func def var_with_same_name(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = 0 - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = 0 for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 0 for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 0 @@ -3029,14 +3040,10 @@ def test_same_name_var(): rt_func = tvm.script.from_source(out_str) tvm.ir.assert_structural_equal(func, rt_func) - assert out_str.count("with T.block([16, 16]) as [vi, vj]") == 4 + assert out_str.count('vi, vj = T.axis.remap("SS", [i, j])') == 2 assert out_str.find("vi_") == -1 assert out_str.find("vj_") == -1 - assert out_str.count("for i0, i1 in T.grid(16, 16)") == 2 - assert out_str.find("i0_") == -1 - assert out_str.find("i1_") == -1 - assert out_str.count("for i, j in T.grid(16, 16)") == 2 assert out_str.find("i_") == -1 assert out_str.find("i_") == -1 @@ -3047,11 +3054,13 @@ def while_loop(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") i = T.alloc_buffer((), "int32", scope="local") - with T.block([16]) as [vi]: - B[vi] = 0 - while i[()] < 10: - for j in range(16): - B[j] += A[j] + for ii in range(16): + with T.block(): + vi = T.axis.S(16, ii) + B[vi] = 0 + while i[()] < 10: + for j in range(16): + B[j] += A[j] def test_while_loop(): @@ -3086,5 +3095,65 @@ def test_primfunc_with_allocate_annotations(): tvm.ir.assert_structural_equal(func, rt_func, True) +# fmt: off +@T.prim_func +def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + threadIdx_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128], dtype="float32") + for i in T.serial(0, 128): + T.launch_thread(threadIdx_x, 128) + reduce_temp0 = T.allocate([1], "float32", "local") + with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) + + +@T.prim_func +def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + threadIdx_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128], dtype="float32") + for i in T.serial(0, 128): + T.launch_thread(threadIdx_x, 128) + reduce_temp0 = T.allocate([1], "float32", "local") + with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) + + +@T.prim_func +def multiple_commreducer() -> None: + normal_reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cross_thread_reduction"): + T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0[0], True, reduce_temp0.data, ax0_1, dtype="handle")) + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_expsum_cross_thread_reduction"): + T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp1[0], True, reduce_temp1.data, ax0_1, dtype="handle")) +# fmt: on + + +def test_primfunc_with_single_reduce_group_commreducer(): + func = comm_reducer_single_reduce_group + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_primfunc_with_multiple_reduce_group_commreducer(): + func = comm_reducer_multiple_reduce_groups + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_primfunc_with_multiple_commreducer(): + func = multiple_commreducer + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index 7138effe395a..dfd2a32165f1 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -30,7 +30,7 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.4.1 +python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0 # Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in # Jenkinsfile. We expect config.cmake to be present from pack_lib(). diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index ecc8ba5d17b0..aba4663d5931 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -15,6 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +set -e +set -u set -o pipefail echo "Checking MyPy Type defs in the TensorIR schedule package." diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 2eb471cbc69f..765c84137730 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -83,11 +83,11 @@ cd .. rm -rf _docs mv docs/_build/html _docs rm -f _docs/.buildinfo -mkdir -p _docs/api -mv docs/doxygen/html _docs/api/doxygen -mv jvm/core/target/site/apidocs _docs/api/javadoc +mkdir -p _docs/reference/api +mv docs/doxygen/html _docs/reference/api/doxygen +mv jvm/core/target/site/apidocs _docs/reference/api/javadoc # mv rust/target/doc _docs/api/rust -mv web/dist/docs _docs/api/typedoc +mv web/dist/docs _docs/reference/api/typedoc echo "Start creating the docs tarball.." # make the tarball diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh index 6632ebb1ca52..8de8b908ee09 100755 --- a/tests/scripts/task_python_microtvm.sh +++ b/tests/scripts/task_python_microtvm.sh @@ -23,13 +23,20 @@ set -x # NOTE(areusch): Adding to diagnose flaky timeouts source tests/scripts/setup-pytest-env.sh make cython3 -run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --zephyr-board=qemu_x86 + +# Zephyr +run_pytest ctypes python-microtvm-zephyr-qemu_x86 tests/micro/zephyr --zephyr-board=qemu_x86 +run_pytest ctypes python-microtvm-zephyr-qemu_riscv32 tests/micro/zephyr --zephyr-board=qemu_riscv32 +run_pytest ctypes python-microtvm-zephyr-qemu_riscv64 tests/micro/zephyr --zephyr-board=qemu_riscv64 + # Temporarily removing mps2_an512 from CI due to issue 8728: # https://github.com/apache/tvm/issues/8728 # run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --zephyr-board=mps2_an521 +# Arduino run_pytest ctypes python-microtvm-arduino apps/microtvm/arduino/template_project/tests run_pytest ctypes python-microtvm-arduino-nano33ble tests/micro/arduino --test-build-only --arduino-board=nano33ble run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build-only --arduino-board=due +# STM32 run_pytest ctypes python-microtvm-stm32 tests/micro/stm32 diff --git a/vta/tutorials/README.txt b/vta/tutorials/README.txt index 3d3858b111ba..c1ff4ca0444d 100644 --- a/vta/tutorials/README.txt +++ b/vta/tutorials/README.txt @@ -1,3 +1,5 @@ +.. _vta-tutorials: + VTA Tutorials ============= This page contains tutorials about VTA and how to use TVM/Relay to target VTA. diff --git a/vta/tutorials/frontend/deploy_detection.py b/vta/tutorials/frontend/deploy_detection.py index 771801851a48..cbd22a752049 100644 --- a/vta/tutorials/frontend/deploy_detection.py +++ b/vta/tutorials/frontend/deploy_detection.py @@ -34,15 +34,15 @@ # # .. code-block:: bash # -# pip3 install "Pillow<7" +# pip3 install "Pillow<7" # # YOLO-V3-tiny Model with Darknet parsing have dependancy with CFFI and CV2 library, # we need to install CFFI and CV2 before executing this script. # -# pip3 install "Pillow<7" +# .. code-block:: bash # -# pip3 install cffi -# pip3 install opencv-python +# pip3 install cffi +# pip3 install opencv-python # # Now return to the python code. Import packages.