From 0b91e539ef788489511d94f60605ffea3d32b6c9 Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Mon, 11 Oct 2021 14:44:02 +0100 Subject: [PATCH 01/84] Initial Implementation of TIRToRuntime Target hook (#9190) * Initial Implementation of TIRToRuntime Target hook This is the initial implementation which wires in a test case for TIRToRuntime, in order to get this working I re-used `CodegenCHost` as it implements all of the `Op`s required from the lowered `PrimFunc`. Currently, the `IRModule` is non-unified but in future work it should definitely do so, I wanted to implement the basics here to get the infra in place. * Fix heterogeneous compute with multiple kDLCPU targets * Remove rogue te_compiler.h include --- .../modules/contrib/ExampleTargetHooks.cmake | 2 +- include/tvm/target/target_kind.h | 28 ++++++++ src/driver/driver_api.cc | 28 +++++++- .../example_target_hooks/relay_to_tir.cc | 19 ++++-- .../contrib/example_target_hooks/target.cc | 39 +++++++++++ .../example_target_hooks/tir_to_runtime.cc | 64 +++++++++++++++++++ src/target/codegen.cc | 5 ++ src/target/source/codegen_c_host.h | 2 +- tests/python/relay/test_target_hooks.py | 23 +++++++ 9 files changed, 199 insertions(+), 11 deletions(-) create mode 100644 src/relay/backend/contrib/example_target_hooks/target.cc create mode 100644 src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc diff --git a/cmake/modules/contrib/ExampleTargetHooks.cmake b/cmake/modules/contrib/ExampleTargetHooks.cmake index eb53dda133d2..e9003b02103e 100644 --- a/cmake/modules/contrib/ExampleTargetHooks.cmake +++ b/cmake/modules/contrib/ExampleTargetHooks.cmake @@ -15,5 +15,5 @@ # specific language governing permissions and limitations # under the License. -file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc) +file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/*.cc) list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC}) diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 8a2bbcbd0121..9d8695a43aff 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_TARGET_KIND_H_ #define TVM_TARGET_TARGET_KIND_H_ +#include #include #include @@ -33,6 +34,33 @@ #include namespace tvm { + +class Target; + +/*! + * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind + * + * Called before the default lowering passes. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. + */ +using FTVMRelayToTIR = transform::Pass; + +/*! + * \brief TIRToRuntime conversion specific to a TargetKind + * + * This function is responsible for scanning an IRModule for appropriate Target-specific functions + and generating a Runtime module representing the compiled output + * + * \param ir_module Unified IRModule + * \param target Target to filter on or retrieve arguments from + * \return Runtime Module containing compiled functions + */ +using FTVMTIRToRuntime = runtime::TypedPackedFunc; + namespace detail { template struct ValueTypeInfoMaker; diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bfea3e7b67c0..2c6fbc2eb76d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -401,12 +401,21 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target auto opt_mixed = transform::Sequential(mixed_pass_list); mod_mixed = opt_mixed(std::move(mod_mixed)); + // We make an assumption here that the overriden host target + // can be used alongside the default host codegen based on device type + // this is so the correct code generator is used later instead of overriding the target. + // We need better support for inserting multiple kDLCPU targets as our current options + // are kDeviceKernelLaunch or not + Target overriden_host_target = target_host; + if (target->kind->device_type == target_host->kind->device_type) { + overriden_host_target = target; + } auto host_pass_list = { Filter([](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; }), - BindTarget(target_host), + BindTarget(overriden_host_target), tir::transform::LowerTVMBuiltin(), tir::transform::LowerCustomDatatypes(), tir::transform::LowerIntrin(), @@ -487,7 +496,9 @@ runtime::Module build(const Map& inputs_arg, const Target& tar for (const auto& it : inputs) { if (it.second.defined()) { - auto pair = SplitDevHostFuncs(it.second, it.first, target_host, pass_ctx); + const Target& target = it.first; + const IRModule& ir_module = it.second; + auto pair = SplitDevHostFuncs(ir_module, target, target_host, pass_ctx); auto& mhost = pair.first; auto& mdevice = pair.second; @@ -495,7 +506,17 @@ runtime::Module build(const Map& inputs_arg, const Target& tar ICHECK(mhost_all.defined()) << "The host module must be defined"; - mhost_all->Update(mhost); + // We don't want library modules going back into host codegen + // unless they're supposed to. Here if we overrode the target host + // to allow lowering previously we check that it's meant to be placed + // back into the host Module. + bool overrides_host_target = target->kind->device_type == target_host->kind->device_type; + bool non_host_target_kind = target->kind != target_host->kind; + if (overrides_host_target && non_host_target_kind) { + device_modules.push_back(codegen::Build(mhost, it.first)); + } else { + mhost_all->Update(mhost); + } if (mdevice->functions.size() != 0) { device_modules.push_back(codegen::Build(mdevice, it.first)); @@ -510,6 +531,7 @@ runtime::Module build(const Map& inputs_arg, const Target& tar mhost.Import(it); } } + return mhost; } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 6d332803041d..cae20210ec4f 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -33,7 +33,9 @@ namespace example_target_hooks { class ConvertAddToSubtract : public MixedModeMutator { public: explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) - : ir_module_(ir_module), host_target_(host_target) {} + : ir_module_(ir_module), + host_target_(host_target), + custom_target_(Target("example_target_hook")) {} IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); @@ -81,7 +83,15 @@ class ConvertAddToSubtract : public MixedModeMutator { tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), buffer_map, DictAttrs(dict_attrs)); - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + + // Switch to TIRToRuntime hook for testing + Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); + if (tir_to_runtime) { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); + } else { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + } + ir_module_->Add(new_global_var, replacement_func); } @@ -109,6 +119,7 @@ class ConvertAddToSubtract : public MixedModeMutator { public: IRModule ir_module_; Target host_target_; + Target custom_target_; }; transform::Pass RelayToTIR() { @@ -124,8 +135,4 @@ transform::Pass RelayToTIR() { } // namespace contrib } // namespace relay -TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) - .set_attr("RelayToTIR", - relay::contrib::example_target_hooks::RelayToTIR()); - } // namespace tvm diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc new file mode 100644 index 000000000000..75b161ad4499 --- /dev/null +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -0,0 +1,39 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { + +namespace relay { +namespace contrib { +namespace example_target_hooks { +tvm::transform::Pass RelayToTIR(); +runtime::Module TIRToRuntime(IRModule mod, Target target); +} // namespace example_target_hooks +} // namespace contrib +} // namespace relay + +TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) + .set_attr("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR()) + .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime); + +} // namespace tvm diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc new file mode 100644 index 000000000000..36d801d349a7 --- /dev/null +++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../../../../target/source/codegen_c_host.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace example_target_hooks { + +using namespace tir; + +class CodeGenExampleTargetHook : public codegen::CodeGenCHost { + public: + /*! + * \brief Emit code that changes adds to multiplies for testing + */ + void VisitExpr_(const SubNode* op, std::ostream& os) final { + os << '('; + PrintExpr(op->a, os); + os << " * "; + PrintExpr(op->b, os); + os << ')'; + } +}; + +runtime::Module TIRToRuntime(IRModule mod, Target target) { + bool output_ssa = false; + bool emit_asserts = false; + CodeGenExampleTargetHook codegen; + Array function_names; + codegen.Init(output_ssa, emit_asserts, target->str()); + for (auto kv : mod->functions) { + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + function_names.push_back(global_symbol.value()); + codegen.AddFunction(prim_func); + } + std::string code = codegen.Finish(); + return codegen::CSourceModuleCreate(code, "c", function_names); +} + +} // namespace example_target_hooks +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 5a4aa39f01b4..41221ad8a33e 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -47,6 +47,11 @@ runtime::Module Build(IRModule mod, Target target) { mod = tir::transform::SkipAssert()(mod); } + auto target_attr_map = tvm::TargetKind::GetAttrMap("TIRToRuntime"); + if (target_attr_map.count(target->kind)) { + return target_attr_map[target->kind](mod, target); + } + // the build function. std::string build_f_name = "target.build." + target->kind->name; const PackedFunc* bf = runtime::Registry::Get(build_f_name); diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 10a437a547c1..4ff1c6ef61ed 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -35,7 +35,7 @@ namespace tvm { namespace codegen { -class CodeGenCHost final : public CodeGenC { +class CodeGenCHost : public CodeGenC { public: CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts, std::string target_str); diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index 4d7a7fcdc15b..5856dc1e1c69 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -49,5 +49,28 @@ def test_tir_external_generation(check_result): check_result(func, inputs, (8,), x_data - y_data) +@pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) +def test_runtime_module_generation(check_result): + shape = (8,) + x_data = np.random.randint(255, size=shape).astype("float32") + y_data = np.random.randint(255, size=shape).astype("float32") + inputs = {"x": x_data, "y": y_data} + + x0 = relay.var("x0", shape=shape, dtype="float32") + y0 = relay.var("y0", shape=shape, dtype="float32") + z = x0 + y0 + func = relay.Function([x0, y0], z) + func = set_external_func_attr(func, "example_target_hook", "replace_add_with_subtract") + # Test hook to trigger TIRToRuntime code generation + func = func.with_attr("tir_to_runtime", True) + + x = relay.var("x", shape=(8,), dtype="float32") + y = relay.var("y", shape=(8,), dtype="float32") + call = relay.Call(func, [x, y]) + func = IRModule.from_expr(call) + + check_result(func, inputs, (8,), x_data * y_data) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 0fa860c8a1a1f3c86e7ad463291d3405f3914653 Mon Sep 17 00:00:00 2001 From: Christian Convey Date: Mon, 11 Oct 2021 10:06:36 -0400 Subject: [PATCH 02/84] [TVM] Add importer for ONNX QLinearMatMul op (#8952) * adds importer code * enables `test_qlinearmatmul_2D` unit test --- python/tvm/relay/frontend/common.py | 31 +++++ python/tvm/relay/frontend/onnx.py | 153 +++++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 1 - 3 files changed, 184 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 3a4897ad3166..825a586918f8 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -835,3 +835,34 @@ def lstm_cell( outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)] return outputs_list, hidden_state, cell_state + + +def ensure_scalar_shape(x): + """ + Assume that `x` is a tensor with one element (regardless of tensor rank). + Return a version of that tensor with rank 0. + """ + x_shape = infer_shape(x) + x_rank = len(x_shape) + + if x_rank == 0: + return x + + num_elem = np.prod(x_shape) + assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape) + + return _op.squeeze(x) + + +def try_resolve_var_to_const(x, graph_params): + """ + Try to resolve the value of tensor `x` to a specific value. + If successful, return a Const op with that value. + If unsuccessful, simply return `x`. + """ + if isinstance(x, _expr.Var) and x.name_hint in graph_params: + value = graph_params[x.name_hint].numpy() + dtype = infer_type(x).checked_type.dtype + return _op.const(value, dtype) + + return x diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 86cb178d0875..3479e1e7c36e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -40,6 +40,7 @@ from .common import ( AttrCvt, Renamer, + ensure_scalar_shape, fold_constant, get_name, get_relay_op, @@ -50,6 +51,7 @@ infer_value, lstm_cell, new_var, + try_resolve_var_to_const, unbind, ) @@ -3506,6 +3508,156 @@ def _impl_v10(cls, inputs, attr, params): return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype) +class QLinearMatMul(OnnxOpConverter): + """ + Operator converter for QLinearMatMul from Microsoft onnxruntime contrib opset. + + Limitations: + - Only supports 2D input tensors. + - Not guaranteed to meet the integer-overflow behavior stipulated in the + ONNX documentation for this operator. + """ + + @classmethod + def _impl_v10(cls, inputs, attr, params): + + # Some of the ops used below take scalar-like inputs, and may require either + # of the following: + # + # - the input is Const node (not merely an expression that *could* be reduced + # to a single Const at graph-compilation time) + # + # - the input has a specific dtype + # + # This function attempts to present 'x' in a form that meets both of those + # requirements. + def try_resolve_to_const_scalar(x, dtype_override=None): + x2 = try_resolve_var_to_const(x, params) + x3 = ensure_scalar_shape(x2) + + x_dtype = infer_type(x).checked_type.dtype + if (dtype_override is not None) and (dtype_override != x_dtype): + x4 = _op.cast(x3, dtype_override) + else: + x4 = x3 + + x5 = fold_constant(x4) + return x5 + + # Unpack the inputs and obtain some type info... + a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp = inputs + + a_type = infer_type(a).checked_type # 'T1' in ONNX doc for this op + a_scale_type = infer_type(a_scale).checked_type + a_zp_type = infer_type(a_zp).checked_type + + b_type = infer_type(b).checked_type # 'T2' in ONNX doc for this op + b_scale_type = infer_type(b_scale).checked_type + b_zp_type = infer_type(b_zp).checked_type + + y_scale_type = infer_type(y_scale).checked_type + y_zp_type = infer_type(y_zp).checked_type # 'T3' in ONNX doc for this op + + a_shape = infer_shape(a) + b_shape = infer_shape(b) + + # Verify type assumptions, based on the ONNX doc for this op... + assert a_type.dtype in ["int8", "uint8"] + assert a_scale_type.dtype == "float32" + assert a_zp_type.dtype == a_type.dtype + + assert b_type.dtype in ["int8", "uint8"] + assert b_scale_type.dtype == "float32" + assert b_zp_type.dtype == b_type.dtype + + assert y_scale_type.dtype == "float32" + assert y_zp_type.dtype in ["int8", "uint8"] + + # TODO: relax this limitation in a future version of this importer. + a_rank = len(a_shape) + b_rank = len(b_shape) + assert (a_rank == 2) and (b_rank == 2), ( + "QLinearMatMul importer currently requires both 'a' and 'b' tensors to be 2D, but" + " rank(a)={}, rank(b)={}".format(a_rank, b_rank) + ) + + # _qnn.op.dense requires the zero-point values to have dtype int32. + a_scale_scalar = try_resolve_to_const_scalar(a_scale) + a_zp_scalar = try_resolve_to_const_scalar(a_zp, "int32") + + b_scale_scalar = try_resolve_to_const_scalar(b_scale) + b_zp_scalar = try_resolve_to_const_scalar(b_zp, "int32") + + y_scale_scalar = try_resolve_to_const_scalar(y_scale) + y_zp_scalar = try_resolve_to_const_scalar(y_zp, "int32") + + # TODO: Confirm that we're using 'num_hidden_units' correctly / as intended with + # the '_qnn.op.dense' instance below. + num_hidden_units = infer_shape(b)[-1] + + # - Specify the matmul result dtype as int32, so that hopefully the matmul will use + # a 32-bit accumulator as seems to be required by the ONNX op's documentation. + # + # TL;DR: + # The ONNX documentation for this op is clear about acceptable overflow + # behavior during the matmul operation: + # - The scalar multiplication ops MAY NOT overflow. + # - The scalar addition ops, which sum the results of the scalar multiplication, + # MAY overflow, but if they do so, it must behave as one would expect during + # 32-bit integer-addition overflow. + # As of this writing, Relay's qnn.op.dense operator doesn't expose a way for us to + # express these constraints. + # + # TODO: Extend TVM / Relay / TIR / etc. to allow this kind of constraint to be + # expressed in a Relay graph. And then update this importer and various TVM + # backends accordingly. + matmul_result_dtype = "int32" + + matmul_result = _qnn.op.dense( + a, + _op.transpose(b), + a_zp_scalar, + b_zp_scalar, + a_scale_scalar, + b_scale_scalar, + num_hidden_units, + matmul_result_dtype, + ) + + # This information might only be found in the C++ code-comments for the + # dense.matmul op, but the quantized tensor returned by _qnn.op.dense + # has scale==(a_scale_scalar * b_scale_scalar), and zero_point==0. + # + # 'matmul_result_zp_scalar' has type 'int32' to satisfy input requirements + # of the [de/re]quantize ops below. + matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar)) + matmul_result_zp_scalar = _op.const(0, dtype="int32") + + # requantize requires y_scale to be constant, + # if y_scale is not constant, doing dequantize -> quantize + if isinstance(y_scale_scalar, _expr.Constant): + y = _qnn.op.requantize( + matmul_result, + matmul_result_scale_scalar, + matmul_result_zp_scalar, + y_scale_scalar, + y_zp_scalar, + axis=-1, + rounding="TONEAREST", + out_dtype=y_zp_type.dtype, + ) + else: + matmul_result_deq = _qnn.op.dequantize( + matmul_result, matmul_result_scale_scalar, matmul_result_zp_scalar, axis=0 + ) + + y = _qnn.op.quantize( + matmul_result_deq, y_scale_scalar, y_zp_scalar, axis=0, out_dtype=y_zp_type.dtype + ) + + return y + + class QLinearMul(OnnxOpConverter): """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset.""" @@ -4234,6 +4386,7 @@ def _get_convert_map(opset): "QLinearConv": QLinearConv.get_converter(opset), "QLinearConcat": QLinearConcat.get_converter(opset), "QLinearAdd": QLinearAdd.get_converter(opset), + "QLinearMatMul": QLinearMatMul.get_converter(opset), "QLinearMul": QLinearMul.get_converter(opset), "QLinearSigmoid": QLinearSigmoid.get_converter(opset), "ConvInteger": ConvInteger.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 69bb44e360ff..2301747034dd 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4941,7 +4941,6 @@ def verify_eyelike(indata): "test_mvn", # This test fails llvm with a lowering error: "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", - "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", "test_range_int32_type_negative_delta_expanded", From 8ba0451d3bff4d6486908d0f9dcf064fe916dd36 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Mon, 11 Oct 2021 16:37:54 +0100 Subject: [PATCH 03/84] Arm(R) Ethos(TM)-U NPU Depthwise2d operator support (#9209) * Arm(R) Ethos(TM)-U NPU Depthwise2d operator support This commit adds support for Depthwise2d primitive operator throughout the TVM stack including Relay legalization pass, operator definition, TE, TIR passes and translation into the command stream. Change-Id: If82b85f5d3b23cd214fe38babd724451bf95ef5b * Change depthwise2d to depthwise_conv2d And respond to other review comments. Change-Id: I58a9f28723750970d386b4d0ba62fa399c5c6181 * Make a line shorter and add a comment Change-Id: Idf4c078bf65e7ed31fe82a92bf334295a82b6ead * Change the order of imports Change-Id: Ic6c77af30a5b9cb68dcc0c173b95490965359481 * Whitespace change Change-Id: I7318bd8cfa5985b33fc7d020cc19057cc9498197 --- .../relay/backend/contrib/ethosu/legalize.py | 94 ++++++++ .../backend/contrib/ethosu/op/__init__.py | 1 + .../backend/contrib/ethosu/op/depthwise.py | 205 +++++++++++++++++ .../backend/contrib/ethosu/te/__init__.py | 1 + .../backend/contrib/ethosu/te/depthwise.py | 148 ++++++++++++ .../backend/contrib/ethosu/tir/depthwise.py | 116 ++++++++++ .../backend/contrib/ethosu/tir/passes.py | 2 + .../relay/backend/contrib/ethosu/tir/spec.py | 2 +- .../relay/backend/contrib/ethosu/tir/utils.py | 6 +- .../contrib/ethosu/tir_to_cs_translator.py | 49 ++++ .../relay/backend/contrib/ethosu/vela_api.py | 15 +- python/tvm/relay/op/contrib/ethosu.py | 82 ++++++- src/relay/op/contrib/ethosu/depthwise.cc | 212 ++++++++++++++++++ tests/python/contrib/test_ethosu/infra.py | 159 +++++++++++-- .../contrib/test_ethosu/test_codegen.py | 94 +++++++- .../contrib/test_ethosu/test_legalize.py | 139 +++++++++++- .../test_replace_depthwise_conv2d.py | 178 +++++++++++++++ .../test_ethosu/test_tir_to_cs_translator.py | 75 +++++++ .../test_ethosu/test_type_inference.py | 96 ++++++++ tests/python/driver/tvmc/test_compiler.py | 6 +- 20 files changed, 1631 insertions(+), 49 deletions(-) create mode 100644 python/tvm/relay/backend/contrib/ethosu/op/depthwise.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/te/depthwise.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py create mode 100644 src/relay/op/contrib/ethosu/depthwise.cc create mode 100644 tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py create mode 100644 tests/python/contrib/test_ethosu/test_type_inference.py diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index fd58da803623..b970aec62c6f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -208,6 +208,99 @@ def __call__(self, *args, **kwargs): pass +class EthosuDepthwiseConv2DRewriter(DFPatternCallback): + """Convert ethosu.qnn_depthwise_conv2d composite functions to ethosu_depthwise_conv2d + operators""" + + def __init__(self): + super().__init__(require_type=True) + self.pattern = ( + wildcard().has_attr( + {"Composite": ethosu_patterns.QnnDepthwiseConv2DParams.composite_name} + ) + )(wildcard()) + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = ethosu_patterns.QnnDepthwiseConv2DParams(post.op.body) + params.ifm.tensor = post.args[0] + channels_map = { + "NHWC": 3, + } + if str(params.ofm.layout) not in channels_map.keys(): + raise UnsupportedLayout(str(params.ofm.layout)) + kernel_shape_map = { + "HWOI": params.weights.shape[0:2], + } + if str(params.weights.layout) not in kernel_shape_map.keys(): + raise UnsupportedLayout(str(params.weights.layout)) + + weights_values = params.weights.values + weights_values_ohwi = np.moveaxis(weights_values, [0, 1, 2, 3], [1, 2, 0, 3]) + + activation = "NONE" + # Activations requiring LUT is currently not supported, so setting it to an empty list + lut = relay.const([], "int8") + clip_min = 0 + clip_max = 0 + if params.activation: + activation = ethosu_patterns.QnnDepthwiseConv2DParams.activation_map[ + params.activation.op.name + ] + if activation == "CLIP": + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + scale_bias = vela_api.pack_biases( + biases=params.biases.tensor.data.asnumpy(), + ifm_scale=params.ifm.q_params.scale_f32, + ifm_dtype=np.dtype(params.ifm.dtype), + weight_scales=params.weights.q_params.scale_f32, + ofm_scale=params.ofm.q_params.scale_f32, + is_activation_tanh_or_sigmoid=activation in ["TANH", "SIGMOID"], + ) + + ethosu_depthwise_conv2d = ethosu_ops.ethosu_depthwise_conv2d( + post.args[0], # IFM + relay.const(weights_values_ohwi, params.weights.values.dtype), + relay.const(scale_bias, "uint8"), + lut, + float(params.ifm.q_params.scale_f32), + int(params.ifm.q_params.zero_point), + int(params.weights.q_params.zero_point), + float(params.ofm.q_params.scale_f32), + int(params.ofm.q_params.zero_point), + kernel_shape_map[str(params.weights.layout)], + params.ofm.shape[channels_map[str(params.ofm.layout)]], + strides=params.strides, + padding=params.padding, + dilation=params.dilation, + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + upscale="NONE", + ifm_layout=str(params.ifm.layout), + ofm_layout=str(params.ofm.layout), + ) + return ethosu_depthwise_conv2d + + +@ir.transform.module_pass(opt_level=1) +class LegalizeEthosUDepthwiseConv2D: + """This is the pass that wraps the EthosUDepthwiseConv2DRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(EthosuDepthwiseConv2DRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + @ir.transform.module_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation @@ -220,6 +313,7 @@ def transform_module( ) -> tvm.ir.IRModule: mod = LegalizeSplit()(mod) mod = LegalizeEthosUConv2D()(mod) + mod = LegalizeEthosUDepthwiseConv2D()(mod) return mod def __call__(self, *args, **kwargs): diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py index 0406298f23f4..1063db6a04c5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py @@ -17,3 +17,4 @@ "Relay operators for the Arm(R) Ethos(TM)-U NPU" from .convolution import ethosu_conv2d +from .depthwise import ethosu_depthwise_conv2d diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py new file mode 100644 index 000000000000..abcddf90b97c --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py @@ -0,0 +1,205 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Relay operator for depthwise convolution""" +from typing import Tuple + +import tvm +from tvm.relay.op import _make +from tvm.topi.generic import schedule_injective +from tvm.relay.op.op import OpStrategy +from tvm.relay.op import strategy as _strategy + +from ..te import depthwise_conv2d_compute + + +def _extract_ethosu_depthwise_conv2d_params(attrs, args): + """Get the parameters necessary to construct a ethosu_depthwise_conv2d compute TE + from a ethosu_depthwise_conv2d Relay call.""" + ifm = args[0] + weight = args[1] + scale_bias = args[2] + lut = args[3] + ifm_scale = attrs.ifm_scale + ifm_zero_point = attrs.ifm_zero_point + weight_zero_point = attrs.weight_zero_point + ofm_scale = attrs.ofm_scale + ofm_zero_point = attrs.ofm_zero_point + strides = attrs.strides + padding = attrs.padding + dilation = attrs.dilation + activation = attrs.activation + clip_min = attrs.clip_min + clip_max = attrs.clip_max + upscale = attrs.upscale + ifm_layout = attrs.ifm_layout + ofm_layout = attrs.ofm_layout + + return ( + ifm, + weight, + scale_bias, + lut, + ifm_scale, + ifm_zero_point, + weight_zero_point, + ofm_scale, + ofm_zero_point, + strides, + padding, + dilation, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) + + +@tvm.ir.register_op_attr("contrib.ethosu.depthwise_conv2d", "FTVMCompute") +def create_ethosu_depthwise_conv2d_compute(attrs, args, out_type): + """Create an ethosu_depthwise_conv2d compute op.""" + params = _extract_ethosu_depthwise_conv2d_params(attrs, args) + op = depthwise_conv2d_compute(*params) + return [op] + + +@tvm.ir.register_op_attr("contrib.ethosu.depthwise_conv2d", "FTVMStrategy") +def depthwise_conv2d_strategy_ethosu(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + create_ethosu_depthwise_conv2d_compute, + _strategy.wrap_topi_schedule(schedule_injective), + name="ethosu_depthwise_conv2d", + ) + return strategy + + +def ethosu_depthwise_conv2d( + ifm: tvm.relay.Expr, + weight: tvm.relay.Expr, + scale_bias: tvm.relay.Expr, + lut: tvm.relay.Expr, + ifm_scale: float, + ifm_zero_point: int, + weight_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + kernel_shape: Tuple[int, int], + ofm_channels: int, + strides: Tuple[int, int] = (1, 1), + padding: Tuple[int, int, int, int] = (0, 0, 0, 0), + dilation: Tuple[int, int] = (1, 1), + activation: str = "NONE", + clip_min: int = 0, + clip_max: int = 0, + upscale: str = "NONE", + ifm_layout: str = "NHWC", + ofm_layout: str = "NHWC", +) -> tvm.relay.Call: + """This is a quantized 2D depthwise convolution operation as supported + by the NPU. It accepts either NHWC or NHCWB16 format + for the input data and OHWI format for the kernel weights. + + Reference: https://developer.arm.com/documentation/102420/0200/ + + Note that the per-channel weight scale and bias tensor must be + packed together into a combined tensor of uint80s. This is represented + in TVM by a (channels, 10) tensor of type uint8. For more detail, + refer to the Technical Reference Manual linked above. + + Parameters + ---------- + ifm : tvm.relay.Expr + The Input Feature Map tensor (IFM). + weight : tvm.relay.Expr + The weight tensor. + scale_bias : tvm.relay.Expr + The packed per-channel weight scale and bias tensor. + lut : tvm.relay.Expr + The look-up table values to use if activation = "LUT" + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + weight_zero_point : int + The quantization zero point for the weight tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + kernel_shape : tuple of int + The 2 dimensional kernel shape as (kernel_height, kernel_width). + ofm_channels : int + The number of OFM channels. + strides : tuple of int, optional + The 2 dimensional strides as (stride_height, stride_width). + padding : tuple of int, optional + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + dilation : tuple of int, optional + The 2 dimensional dilation as (dilation_height, dilation_width). + activation : str, optional + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform + the activation function. + clip_min : int, optional + The minimum clipping value if activation = "CLIP" + clip_max : int, optional, + The maximum clipping value if activation = "CLIP" + upscale : str, optional + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + out : tvm.relay.Call + A call to the ethosu_depthwise_conv2d op. + + """ + return _make.ethosu_depthwise_conv2d( + ifm, + weight, + scale_bias, + lut, + ifm_scale, + ifm_zero_point, + weight_zero_point, + ofm_scale, + ofm_zero_point, + kernel_shape, + ofm_channels, + strides, + padding, + dilation, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py index 7ca5de3c160c..5dcdd4dcf602 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py @@ -17,3 +17,4 @@ """Tensor Expressions for the NPU""" from .convolution import * +from .depthwise import * diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py new file mode 100644 index 000000000000..35ae7f9a700a --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Tensor Expressions for depthwise convolutions""" +from typing import Tuple, Union, List + +from tvm import te +from .dma import dma_ofm_compute, dma_ifm_compute + + +def depthwise_conv2d_compute( + ifm: te.Tensor, + weight: te.Tensor, + scale_bias: te.Tensor, + lut: te.Tensor, + ifm_scale: float, + ifm_zero_point: int, + weight_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + strides: Tuple[int, int], + padding: Tuple[int, int, int, int], + dilation: Union[Tuple[int, int], List[int]], + activation: str, + clip_min: int, + clip_max: int, + upscale: str, + ifm_layout: str, + ofm_layout: str, +) -> te.Tensor: + """A compute operator representing the capabilities of 2D convolution for the NPU. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map tensor (IFM). + weight : te.Tensor + The weight tensor. + scale_bias : te.Tensor + The packed per-channel weight scale and bias tensor. + lut : te.Tensor + The look-up table values to use if activation = "LUT". + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + weight_zero_point : int + The quantization zero point for the weight tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + strides : tuple + The 2 dimensional strides as (stride_height, stride_width). + padding : tuple + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + dilation : Union[int, tuple, list] + The 2 dimensional dilation as (dilation_height, dilation_width). + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int + The minimum clipping value if activation = "CLIP". + clip_max : int + The maximum clipping value if activation = "CLIP". + upscale : str + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + te.Tensor + The OFM tensor. + + """ + assert ifm.shape[0] == 1, f"Only batch size 1 is supported" + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + + stride_h, stride_w = strides + dilation_h, dilation_w = dilation + channels, kernel_h, kernel_w, _ = weight.shape + + # Compute operation for the IFM DMA pipeline + dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, channels, padding) + + # 2D Depthwise Convolution compute operation + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + ofm_height = (dmaed_ifm.shape[1] - dilated_kernel_h) // stride_h + 1 + ofm_width = (dmaed_ifm.shape[2] - dilated_kernel_w) // stride_w + 1 + rh = te.reduce_axis((0, kernel_h), name="ry") + rw = te.reduce_axis((0, kernel_w), name="rx") + + depthwise_conv2d_attrs = { + "op": "ethosu_depthwise_conv2d", + "weight_zero_point": weight_zero_point, + "activation": activation, + "upscale": upscale, + "clip_min": clip_min, + "clip_max": clip_max, + "stride_h": stride_h, + "stride_w": stride_w, + "dilation_h": dilation_h, + "dilation_w": dilation_w, + } + + depthwise = te.compute( + (1, ofm_height, ofm_width, channels), + lambda nn, hh, ww, cc: te.sum( + dmaed_ifm( + nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, cc + ).astype(ifm.dtype) + * weight[cc, rh, rw, 0].astype(ifm.dtype) + # This is a trick to load 10 elements of the scale_bias at once, not accurate maths + + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype), + axis=[rh, rw], + ), + name="ethosu_depthwise_conv2d", + attrs=depthwise_conv2d_attrs, + ) + + # Compute operation for the OFM DMA pipeline + return dma_ofm_compute(depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py new file mode 100644 index 000000000000..27111a970b27 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the depthwise convolution operators in TIR.""" +from typing import Dict, Tuple +import tvm +from ..vela_api import SCALE_BIAS_LENGTH +from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores +from .dma import get_ifm_params, get_ofm_params +from .spec import ( + SerialKernel, + SerialAddressRange, + SerialActivation, + Serial2DDepthwise, +) + + +def get_depthwise_conv2d_params( + stmt: tvm.tir.AttrStmt, + producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], +) -> Tuple[Serial2DDepthwise, tvm.tir.Var, tvm.tir.Var]: + """Get the parameters necessary to construct a call_extern for a depthwise_conv2d. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a depthwise loop nest. + producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + Serial2DDepthwise + The parameters needed to construct a 2D depthwise. + output_pointer : tvm.tir.Var + The output pointer of the convolution operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the convolution output pointer. + + """ + attrs, body = get_op_attrs(stmt) + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + rh = inner + rw = rh.body + # loads = [output, input, weights, scale_bias, scale_bias] + loads = get_loads(rw.body) + # stores = [output] + stores = get_stores(rw.body) + input_pointer = loads[1].buffer_var + output_pointer = stores[0].buffer_var + # Get feature map info + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get kernel info + serial_kernel = SerialKernel( + width=int(rw.extent), + height=int(rh.extent), + stride_w=int(attrs["stride_w"]), + stride_h=int(attrs["stride_h"]), + dilation_w=int(attrs["dilation_w"]), + dilation_h=int(attrs["dilation_h"]), + ) + # Get scale_bias info + scale_bias_load = loads[3] + scale_bias_base = get_base_address(scale_bias_load.index) + serial_scale_bias = SerialAddressRange( + address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + length=SCALE_BIAS_LENGTH * serial_ofm[3], + ) + # Get weight info + weight_load = loads[2] + weight_base = get_base_address(weight_load.index) + serial_weight = SerialAddressRange( + address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1], + ) + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + + return ( + Serial2DDepthwise( + ifm=serial_ifm, + ofm=serial_ofm, + kernel=serial_kernel, + weight=serial_weight, + weight_zero_point=attrs["weight_zero_point"], + scale_bias=serial_scale_bias, + padding=serial_padding, + activation=serial_activation, + upscale="NONE", + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 1af44962c141..8bb410e986c7 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -21,6 +21,7 @@ import tvm from tvm.relay.backend.contrib.ethosu import vela_api from .convolution import get_conv2d_params +from .depthwise import get_depthwise_conv2d_params from .transform import get_copy_params from .utils import get_weights_pointer, get_scale_bias_pointer @@ -52,6 +53,7 @@ def ReplaceOperators(): op_map = { "ethosu_conv2d": get_conv2d_params, "ethosu_copy": get_copy_params, + "ethosu_depthwise_conv2d": get_depthwise_conv2d_params, } pointer_to_producer = {} pointer_to_consumer = {} diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index 3ecbcd5f3cdc..ff019c7783db 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -203,7 +203,7 @@ def __init__( class Serial2DDepthwise(SerializableFormat): """Specialization class to retrieve arguments of - a ethosu.depthwise2d tir extern call on a predefined ordering""" + a ethosu.depthwise_conv2d TIR extern call on a predefined ordering""" def __init__( self, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index 7d6fd3bf82d8..ccfc2dfbfc48 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -23,7 +23,8 @@ # TODO(@mbaret): Formalise this with a specification def get_weights_pointer(tir_extern_call): """Get the weights pointer from a NPU extern call if it exists""" - if tir_extern_call.args[0] == "ethosu_conv2d": + supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] + if tir_extern_call.args[0] in supported_ops: return tir_extern_call.args[41].buffer_var return None @@ -31,7 +32,8 @@ def get_weights_pointer(tir_extern_call): # TODO(@mbaret): Formalise this with a specification def get_scale_bias_pointer(tir_extern_call): """Get the scale_bias pointer from a NPU extern call if it exists""" - if tir_extern_call.args[0] == "ethosu_conv2d": + supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] + if tir_extern_call.args[0] in supported_ops: return tir_extern_call.args[44].buffer_var return None diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 4b28dc5b191e..408eab6427ca 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -299,6 +299,7 @@ def translate_ethosu_tir_extern_call(tir_extern_call): supported_extern_calls = { "ethosu_conv2d": translate_ethosu_conv2d, "ethosu_copy": translate_ethosu_copy, + "ethosu_depthwise_conv2d": translate_ethosu_depthwise_conv2d, } ext_call_type = tir_extern_call.args[0].value assert ext_call_type in supported_extern_calls.keys(), f"{ext_call_type} is not yet supported" @@ -408,6 +409,54 @@ def _create_npu_op_conv2d(serial_2d_convolution): return npu_conv2d_op, weights_zero_point +def translate_ethosu_depthwise_conv2d(tir_extern_call): + """This function will translate a tir extern_call + as produced by Relay to TIR compilation. + + Parameters + ---------- + tir_extern_call : tvm.tir.Call + This should be a tir external call that has an agreed upon ordering + for NPU TIR Compiler. See Serial2DDepthwise in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. + + Returns + ------- + ethosu.vela.api.NpuDepthWiseOperation + The vela object containing the params of ethosu_depthwise_conv2d + weights_zero_point : int + The zero point of the weights + """ + serial_object = spec.create_serial_object(spec.Serial2DDepthwise, tir_extern_call.args[1:]) + return _create_npu_op_depthwise_conv2d(serial_object) + + +def _create_npu_op_depthwise_conv2d(serial_2d_depthwise): + npu_depthwise_conv2d_op = vapi.NpuConvDepthWiseOperation() + + npu_depthwise_conv2d_op.ifm = _create_npu_feature_map(serial_2d_depthwise.ifm) + npu_depthwise_conv2d_op.ofm = _create_npu_feature_map(serial_2d_depthwise.ofm) + npu_depthwise_conv2d_op.kernel = _create_npu_kernel(serial_2d_depthwise.kernel) + npu_depthwise_conv2d_op.weights = [_create_npu_address_range(serial_2d_depthwise.weight)] + weights_zero_point = np.int64(serial_2d_depthwise.weight_zero_point.value) + npu_depthwise_conv2d_op.biases = [_create_npu_address_range(serial_2d_depthwise.scale_bias)] + npu_depthwise_conv2d_op.padding = _create_npu_padding(serial_2d_depthwise.padding) + + npu_depthwise_conv2d_op.activation = _create_npu_activation(serial_2d_depthwise.activation) + if ( + npu_depthwise_conv2d_op.activation + and npu_depthwise_conv2d_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_depthwise_conv2d_op) + + npu_depthwise_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale) + target_accel_type = vela_api.get_target_accel_type() + block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_type) + npu_depthwise_conv2d_op.block_config = block_config + + return npu_depthwise_conv2d_op, weights_zero_point + + def _create_npu_feature_map(serial_feature_map): """This is a helper function to capture a list of arguments to create Vela NpuFeatureMap object diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 5009c3157c77..6523352a0eea 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -130,18 +130,22 @@ def encode_weights(tir_extern_call, values, accel_type): bytearray Compressed weights """ - supported_ops = ["ethosu_conv2d"] + supported_ops = { + "ethosu_conv2d": tirtocs.translate_ethosu_conv2d, + "ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d, + } op = str(tir_extern_call.args[0].value) - assert op in supported_ops - npu_op, weights_zero_point = tirtocs.translate_ethosu_conv2d(tir_extern_call) + assert op in supported_ops.keys() + npu_op, weights_zero_point = supported_ops[op](tir_extern_call) block_config = get_optimal_block_config(npu_op, accel_type) # The weight layout is assumed to be flat OHWI, always. assert len(values.shape) == 1 + is_depthwise = op == "ethosu_depthwise_conv2d" shape_ohwi = ( npu_op.ofm.shape.depth, npu_op.kernel.height, npu_op.kernel.width, - npu_op.ifm.shape.depth, + 1 if is_depthwise else npu_op.ifm.shape.depth, ) assert values.size == np.prod(shape_ohwi) values = np.reshape(values, shape_ohwi) @@ -154,8 +158,7 @@ def encode_weights(tir_extern_call, values, accel_type): block_depth=block_config.depth, dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y), accel_type=accel_type, - # TODO(@manupa-arm): change this when we support depthwise - is_depthwise=False, + is_depthwise=is_depthwise, ) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 85ddfd9a7ec8..4369376b5689 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -192,11 +192,11 @@ def __init__(self, func_body: tvm.relay.Function): bias_add = requantize_op.args[0] qnn_conv2d = bias_add.args[0] data_layout = qnn_conv2d.attrs.data_layout - kernel_layout = qnn_conv2d.attrs.kernel_layout + self.kernel_layout = qnn_conv2d.attrs.kernel_layout # We consider the weights & biases as params as it should be a Constant self.weights = TensorParams( qnn_conv2d.args[QConv2DArgs.WEIGHTS.value], - kernel_layout, + self.kernel_layout, qnn_conv2d.args[QConv2DArgs.WEIGHTS_SCALE.value], qnn_conv2d.args[QConv2DArgs.WEIGHTS_ZERO_POINT.value], ) @@ -219,16 +219,18 @@ def __init__(self, func_body: tvm.relay.Function): requantize_op.args[RequantArgs.OFM_SCALE.value], requantize_op.args[RequantArgs.OFM_ZERO_POINT.value], ) - self.padding = qnn_conv2d.attrs.padding - self.strides = qnn_conv2d.attrs.strides - self.dilation = qnn_conv2d.attrs.dilation + attrs = qnn_conv2d.attrs + self.padding = attrs.padding + self.strides = attrs.strides + self.dilation = attrs.dilation self.activation = activation + self.channels = attrs.channels # If groups are equal to channel, its a depthwise_conv2d - self.groups = qnn_conv2d.attrs.groups + self.groups = attrs.groups self.is_depthwise = False channels_axis = {"HWIO": 3, "HWOI": 2} - if qnn_conv2d.attrs.groups == self.weights.shape[channels_axis[kernel_layout]]: + if self.groups == self.weights.shape[channels_axis[self.kernel_layout]]: self.is_depthwise = True def is_valid(self) -> bool: @@ -253,10 +255,52 @@ def is_valid(self) -> bool: legal_groups = [1, self.ofm.shape[3]] if self.groups not in legal_groups: return False - # This should be a valid QnnDepthwise2DParams, not QnnConv2DParams + # This should be a valid QnnDepthwiseConv2DParams, not QnnConv2DParams return not self.is_depthwise +class QnnDepthwiseConv2DParams(QnnConv2DParams): + """ + This class will parse a call to a ethosu.depthwise_conv2d composite function + and extract the parameter information. + """ + + composite_name = "ethosu.depthwise_conv2d" + # The hardware only supports padding upto the numbers as follows + padding_bounds = [31, 31, 32, 32] + + def __init__(self, func_body: tvm.relay.expr.Call): + QnnConv2DParams.__init__(self, func_body) + + def is_valid(self): + """ + Checks whether QnnDepthwiseConv2D + activation function has compatible attributes with HW + """ + tensor_params = [self.weights, self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params): + return False + if not check_weights(self.weights, self.dilation): + return False + if not check_bias(self.biases): + return False + if not check_strides(self.strides): + return False + if not check_batch_size(self.ifm): + return False + if not check_dilation(self.dilation): + return False + if not check_padding(self.padding, self.padding_bounds): + return False + if self.weights.layout != "HWOI": + return False + # only depth multiplier of size 1 is supported + if self.weights.shape[3] != 1: + return False + if not self.is_depthwise: + return False + return True + + def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ This function creates the pattern for qnn.conv2D with optional fused RELU activation. @@ -272,6 +316,21 @@ def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: return clip_or_req +def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for depthwise qnn.conv2D with optional fused RELU activation. + """ + qnn_conv2d = is_op("qnn.conv2d")( + wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + ).has_attr({"kernel_layout": "HWOI"}) + bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) + req = is_op("qnn.requantize")( + bias_add, is_constant(), is_constant(), is_constant(), is_constant() + ) + clip_or_req = req.optional(is_op("clip")) + return clip_or_req + + @register_pattern_table("ethosu") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -279,7 +338,12 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal QnnConv2DParams.composite_name, qnn_conv2d_pattern(), lambda pat: QnnConv2DParams(pat).is_valid(), - ) + ), + ( + QnnDepthwiseConv2DParams.composite_name, + qnn_depthwise_conv2d_pattern(), + lambda pat: QnnDepthwiseConv2DParams(pat).is_valid(), + ), ] diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc new file mode 100644 index 000000000000..fa73645d45de --- /dev/null +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/depthwise.cc + * \brief Depthwise convolution 2D operator definition for the Arm(R) Ethos(TM)-U NPU + */ +#include +#include +#include +#include +#include + +#include "../../../qnn/utils.h" +#include "../../nn/convolution.h" +#include "common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes used by the Ethos(TM)-U NPU depthwise operator */ +struct EthosuDepthwiseConv2DAttrs : public tvm::AttrsNode { + double ifm_scale; + int ifm_zero_point; + int weight_zero_point; + double ofm_scale; + int ofm_zero_point; + Array kernel_shape; + IndexExpr ofm_channels; + Array strides; + Array padding; + Array dilation; + String activation; + int clip_min; + int clip_max; + String upscale; + String ifm_layout; + String ofm_layout; + + TVM_DECLARE_ATTRS(EthosuDepthwiseConv2DAttrs, "relay.attrs.EthosuDepthwiseConv2DAttrs") { + TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(weight_zero_point) + .describe("The quantization zero point for the weight tensor."); + TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(kernel_shape) + .describe("The 2 dimensional kernel shape as (kernel_height, kernel_width).") + .set_default(NullValue >()); + TVM_ATTR_FIELD(ofm_channels) + .describe("The number of OFM channels.") + .set_default(NullValue()); + TVM_ATTR_FIELD(strides) + .describe("The 2 dimensional strides as (stride_height, stride_width).") + .set_default(Array({1, 1})); + TVM_ATTR_FIELD(padding) + .describe("The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right)") + .set_default(Array({0, 0, 0, 0})); + TVM_ATTR_FIELD(dilation) + .describe("The 2 dimensional dilation as (dilation_height, dilation_width).") + .set_default(Array({1, 1})); + TVM_ATTR_FIELD(activation) + .describe( + "Description: The activation function to use." + "'NONE' - no activation function." + "'CLIP' - clip the output between clip_min and clip_max." + "'TANH - tanh activation function." + "'SIGMOID' - sigmoid activation function." + "'LUT' - use a look-up table to perform the activation function.") + .set_default("NONE"); + TVM_ATTR_FIELD(clip_min) + .describe("The minimum clipping value if activation = CLIP.") + .set_default(0); + TVM_ATTR_FIELD(clip_max) + .describe("The maximum clipping value if activation = CLIP.") + .set_default(0); + TVM_ATTR_FIELD(upscale) + .describe( + "The 2x2 upscaling mode to apply to the Input Feature Map tensor. " + "'NONE' - no upscaling. " + "'NEAREST' - upscale using nearest neighbour. " + "'ZEROS' - upscale using zeros.") + .set_default("NONE"); + TVM_ATTR_FIELD(ifm_layout) + .set_default("NHWC") + .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'."); + TVM_ATTR_FIELD(ofm_layout) + .set_default("NHWC") + .describe("The layout of the Output Feature Map tensor. Can be 'NHWC' or 'NHCWB16'."); + } +}; + +TVM_REGISTER_NODE_TYPE(EthosuDepthwiseConv2DAttrs); + +bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 5); + const auto* ifm = types[0].as(); + const auto* weight = types[1].as(); + const auto* scale_bias = types[2].as(); + if (ifm == nullptr || weight == nullptr) return false; + + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr."; + ICHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8)) + << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for ifm but was " + << ifm->dtype; + ICHECK(weight->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8)) + << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for weight but was " + << weight->dtype; + ICHECK(scale_bias->dtype == DataType::UInt(8)) + << "Expected ethosu_depthwise_conv2d type(uint8) for scale_bias but was " + << scale_bias->dtype; + + // Collect the ifm, weight and ofm tensors for using in the inference function + Array tensor_types = {types[0], types[1], types[4]}; + + // Assign weight type {ofm_channels, kernel_height, kernel_width, 1} + reporter->Assign(types[1], TensorType({param->ofm_channels, param->kernel_shape[0], + param->kernel_shape[1], weight->shape[3]}, + weight->dtype)); + + // Assign ofm type + auto ofm_shape = + EthosuInferKernelOutput(ifm->shape, param->ifm_layout, param->ofm_layout, param->kernel_shape, + param->ofm_channels, param->dilation, param->strides, param->padding); + + reporter->Assign(types[4], TensorType(ofm_shape, ifm->dtype)); + + return true; +} + +Expr MakeEthosuDepthwiseConv2D(Expr ifm, Expr weight, Expr scale_bias, Expr lut, double ifm_scale, + int ifm_zero_point, int weight_zero_point, double ofm_scale, + int ofm_zero_point, Array kernel_shape, + IndexExpr ofm_channels, Array strides, + Array padding, Array dilation, + String activation, int clip_min, int clip_max, String upscale, + String ifm_layout, String ofm_layout) { + auto attrs = make_object(); + attrs->ifm_scale = ifm_scale; + attrs->ifm_zero_point = ifm_zero_point; + attrs->weight_zero_point = weight_zero_point; + attrs->ofm_scale = ofm_scale; + attrs->ofm_zero_point = ofm_zero_point; + attrs->kernel_shape = std::move(kernel_shape); + attrs->ofm_channels = std::move(ofm_channels); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->activation = std::move(activation); + attrs->clip_min = clip_min; + attrs->clip_max = clip_max; + attrs->upscale = std::move(upscale); + attrs->ifm_layout = std::move(ifm_layout); + attrs->ofm_layout = std::move(ofm_layout); + static const Op& op = Op::Get("contrib.ethosu.depthwise_conv2d"); + return Call(op, {ifm, weight, scale_bias, lut}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.ethosu_depthwise_conv2d") + .set_body_typed(MakeEthosuDepthwiseConv2D); + +RELAY_REGISTER_OP("contrib.ethosu.depthwise_conv2d") + .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized depthwise operator. + +This Relay operator corresponds to the hardware-implemented quantized +depthwise operation found on Ethos(TM)-U NPUs. It accepts either NHWC or NHCWB16 format +for the input data (input feature map, or IFM) and OHWI format for the kernel weights. + +- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **weight**: (ofm_channels, kernel_shape[0], kernel_shape[1], 1 (depth multiplier)) +- **scale_bias**: (ofm_channels, 10) +- **ofm**: (1, ofm_height, ofm_width, ofm_channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") + .add_argument("weight", "Tensor", "The weight tensor.") + .add_argument("scale_bias", "Tensor", "The packed per-channel weight scale and bias tensor.") + .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'") + .set_support_level(11) + .add_type_rel("EthosuDepthwiseConv2D", EthosuDepthwiseConv2DRel); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 8b0d3063a696..01a7ceb9ed56 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -29,7 +29,9 @@ import os import struct import numpy +import math from enum import IntEnum +import tensorflow as tf from ethosu.vela.register_command_stream_generator import CmdMode from ethosu.vela.register_command_stream_generator import cmd0 @@ -66,26 +68,6 @@ def __init__(self): self.npu_ops = set() -def parse_relay_tflite_model(tflite_model, input_tensor, input_shape, input_dtype): - mod_, params_ = relay.frontend.from_tflite( - tflite_model, - shape_dict={input_tensor: input_shape}, - dtype_dict={input_tensor: input_dtype}, - ) - return mod_, params_ - - -def parse_tflite_model(model_file): - try: - import tflite - - return tflite.Model.GetRootAsModel(model_file, 0) - except AttributeError: - import tflite.Model - - return tflite.Model.Model.GetRootAsModel(model_file, 0) - - def print_payload(payload): cmds = deserialize_command_stream(payload) for cmd_val in cmds: @@ -270,6 +252,58 @@ def flatten_numpy_data(data): return reshaped_data +class InputGenerator: + def __init__(self, random_state): + self._random_state = random_state + + def generate(self, size, dtype): + if dtype == numpy.float32: + print("random float32") + return self._random_state.uniform(-1, 1, size).astype(dtype) + else: + print("random (u)int min=%d max=%d", numpy.iinfo(dtype).min, numpy.iinfo(dtype).max) + low = numpy.iinfo(dtype).min + high = numpy.iinfo(dtype).max + 1 + return self._random_state.randint(low, high, size, dtype) + + +def generate_ref_data_tflite(model): + """ + This method generates reference data by running the specified model on tflite with random input data. + The random input data and generated output data are returned. + """ + expected_output_data = {} + interpreter = tf.lite.Interpreter(model_content=model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Initialize random generators with a fixed seed to get deterministic results + seed = 0 + random_state = numpy.random.RandomState(seed) + + inputgen = InputGenerator(random_state) + + # Generate input data + input_data = { + input_detail["name"]: inputgen.generate( + input_detail["shape"], + input_detail["dtype"], + ) + for input_detail in input_details + } + for index, value in enumerate(input_data.values()): + interpreter.set_tensor(index, value) + interpreter.invoke() + + expected_output_data = [ + interpreter.get_tensor(output_detail["index"]) for output_detail in output_details + ] + + return input_data, expected_output_data + + def generate_weights_data(shape, dtype): size = 1 for dim in shape: @@ -278,7 +312,7 @@ def generate_weights_data(shape, dtype): def get_convolutional_args(call, include_buffers=False, remove_constants=False): - """A method to extract the arguments from conv2d or depthwise2d extern call.""" + """A method to extract the arguments from conv2d or depthwise_conv2d extern call.""" args = call.args conv_args = [] remove_indices = [0] @@ -299,6 +333,44 @@ def get_convolutional_args(call, include_buffers=False, remove_constants=False): return conv_args +def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]): + assert len(strides) == 2 + assert len(dilation) == 2 + assert len(kernel_shape) == 2 + if padding.lower() == "valid": + h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0]) + w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1]) + if padding.lower() == "same": + h = math.ceil(ifm_shape[1] / strides[0]) + w = math.ceil(ifm_shape[2] / strides[1]) + ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]] + return ofm_shape + + +def compute_padding_shape(ifm_shape, ofm_shape, padding, kernel_shape, strides, dilation=[1, 1]): + assert len(strides) == 2 + assert len(dilation) == 2 + assert len(kernel_shape) == 2 + if padding.lower() == "valid": + return [0, 0, 0, 0] + if padding.lower() == "same": + effective_kernel_shape = [ + dilation[0] * (kernel_shape[0] - 1) + 1, + dilation[1] * (kernel_shape[1] - 1) + 1, + ] + pad_along_height = max( + (ofm_shape[1] - 1) * strides[0] + effective_kernel_shape[0] - ifm_shape[1], 0 + ) + pad_along_width = max( + (ofm_shape[2] - 1) * strides[1] + effective_kernel_shape[1] - ifm_shape[2], 0 + ) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + return [pad_top, pad_left, pad_bottom, pad_right] + + def make_ethosu_conv2d( ifm, ifm_channels, @@ -343,3 +415,48 @@ def make_ethosu_conv2d( ofm_layout=ofm_layout, ) return conv + + +def make_ethosu_depthwise_conv2d( + ifm, + channels, + kernel_shape, + padding, + strides, + dilation, + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", + weight_dtype="int8", +): + # params + weight_shape = (channels, kernel_shape[0], kernel_shape[1], 1) + padding = get_pad_tuple(padding, kernel_shape) + + scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8") + scale_bias = relay.const(scale_bias_data, dtype="uint8") + weight_data = generate_weights_data(weight_shape, weight_dtype) + weight = relay.const(weight_data, dtype=weight_dtype) + depthwise = ethosu_ops.ethosu_depthwise_conv2d( + ifm, + weight, + scale_bias, + lut=relay.const([], dtype="int8"), + ifm_scale=0.6, + ifm_zero_point=11, + weight_zero_point=13, + ofm_scale=0.26, + ofm_zero_point=15, + kernel_shape=kernel_shape, + ofm_channels=channels, + strides=strides, + padding=padding, + dilation=dilation, + activation=activation, + clip_min=15 if activation == "CLIP" else 0, + clip_max=105 if activation == "CLIP" else 0, + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + return depthwise diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 1944de5f94c0..4949d6814ab2 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -18,14 +18,12 @@ import pytest pytest.importorskip("ethosu.vela") -import os import numpy as np -import pathlib +import tflite.Model import tvm -import tvm.micro as micro +import tensorflow as tf from tvm import relay -from tvm.relay.backend.contrib import ethosu from tvm.relay.backend.contrib.ethosu import util from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tests.python.relay.aot.aot_test_utils import generate_ref_data @@ -168,5 +166,93 @@ def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_ infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)]) +@pytest.mark.parametrize( + "kernel_shape, activation", + [((3, 3), "relu"), ((1, 2), None)], +) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))]) +def test_tflite_depthwise_conv2d( + accel_type, + ifm_shape, + kernel_shape, + padding, + strides, + dilation, + activation, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def depthwise_conv2d(self, x): + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + # The input strides to the TensorFlow API needs to be of shape 1x4 + tf_strides = [1, strides[0], strides[1], 1] + op = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=padding, dilations=dilation + ) + if activation: + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.depthwise_conv2d.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 911a0e6eefc6..b9a588d4aec0 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -20,15 +20,33 @@ pytest.importorskip("ethosu.vela") import numpy as np +import tensorflow as tf +import tflite.Model import tvm from tvm import relay -from tvm.relay.backend.contrib import ethosu from tvm.relay.backend.contrib.ethosu import legalize, preprocess -from tvm.relay.dataflow_pattern import * -from tvm.relay.op.contrib.ethosu import * +from tvm.relay import dataflow_pattern +from tvm.relay.op.contrib import ethosu +from tvm.relay.build_module import bind_params_by_name from . import relay_ir_builder +from . import infra + + +def partition_ethosu_by_table(mod, pattern_table): + """In case only the legalization part is supported for an operator, we don't + want to add the operator's pattern to the pattern table so that the compiler + wouldn't attempt to offload an operator without full stack support.""" + mod = relay.transform.InferType()(mod) + mod = relay.transform.MergeComposite(pattern_table)(mod) + mod = relay.transform.AnnotateTarget("ethosu")(mod) + mod = relay.transform.MergeCompilerRegions()(mod) + mod = relay.transform.InferType()(mod) + mod = relay.transform.PartitionGraph()(mod) + mod = relay.transform.InferType()(mod) + mod = preprocess.preprocess_ext_io()(mod) + return mod def test_split_indices_legalize(): @@ -294,7 +312,7 @@ def verify_linear(ext_func, conv2d_params): ] for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) - mod = partition_for_ethosu(mod) + mod = ethosu.partition_for_ethosu(mod) mod = legalize.LegalizeEthosUConv2D()(mod) verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params) @@ -327,12 +345,123 @@ def create_graph_single_unsupported_ifm_layout( for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) - mod = partition_for_ethosu(mod) + mod = ethosu.partition_for_ethosu(mod) with pytest.raises( tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW" ): mod = legalize.LegalizeEthosUConv2D()(mod) +@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)]) +@pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)]) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) +@pytest.mark.parametrize("activation", ["RELU", None]) +def test_tflite_depthwise_conv_2d_legalize( + ifm_shape, kernel_shape, padding, strides, dilation, activation +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def depthwise_conv2d(self, x): + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + # The input strides to the TensorFlow API needs to be of shape 1x4 + tf_strides = [1, strides[0], strides[1], 1] + op = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=padding, dilations=dilation + ) + if activation: + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.depthwise_conv2d.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + op = ext_func.body + ofm_channels = op.attrs.ofm_channels + + # check IFM + ifm = op.args[0].checked_type + assert list(ifm.shape) == list(ifm_shape) + assert str(ifm.dtype) == dtype + assert ifm.shape[3] == ofm_channels + + # check OFM + ofm = op.checked_type + expected_ofm_shape = infra.compute_ofm_shape( + ifm_shape, padding, kernel_shape, strides, dilation + ) + assert list(ofm.shape) == list(expected_ofm_shape) + assert str(ofm.dtype) == dtype + assert ofm.shape[3] == ofm_channels + + # check weights + weights_ohwi = op.args[1].data.asnumpy() + assert str(weights_ohwi.dtype) == dtype + assert weights_ohwi.shape[0] == ofm_channels + assert weights_ohwi.shape[1] == kernel_shape[0] + assert weights_ohwi.shape[2] == kernel_shape[1] + assert weights_ohwi.shape[3] == 1 # only depth multiplier 1 is supported + + # Check that scale_bias matches weight tensor + assert list(op.args[2].checked_type.shape)[0] == ofm_channels + + expected_padding = infra.compute_padding_shape( + ifm_shape, expected_ofm_shape, padding, kernel_shape, strides, dilation + ) + assert list(op.attrs.padding) == list(expected_padding) + assert op.attrs.ofm_channels == ofm_channels + assert list(op.attrs.strides) == list(strides) + assert list(op.attrs.dilation) == list(dilation) + if activation == "RELU": + assert str(op.attrs.activation) == "CLIP" + + depthwise_pattern_table = [ + ( + ethosu.QnnDepthwiseConv2DParams.composite_name, + ethosu.qnn_depthwise_conv2d_pattern(), + lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(), + ) + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + + mod["main"] = bind_params_by_name(mod["main"], params) + mod = partition_ethosu_by_table(mod, depthwise_pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + legalize.EthosuDepthwiseConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py new file mode 100644 index 000000000000..b3ce74c4e84a --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +pytest.importorskip("ethosu.vela") + +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from .infra import make_ethosu_depthwise_conv2d, get_convolutional_args + + +@pytest.mark.parametrize( + "trial", + [ + [(1, 8, 8, 3), 3, (3, 2), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC"], + [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC"], + [(1, 8, 8, 3), 3, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC"], + [(1, 1, 1, 1), 1, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC"], + [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC"], + [(1, 8, 2, 8, 16), 18, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHCWB16", "NHWC"], + [(1, 7, 9, 40), 40, (3, 2), (1, 2), (2, 1), (1, 2), "CLIP", "NHWC", "NHCWB16"], + [(1, 4, 12, 9, 16), 182, (2, 3), (6, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], + [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHWC"], + [(1, 7, 9, 41), 41, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHCWB16"], + [ + (1, 13, 12, 19, 16), + 182, + (1, 3), + (5, 3), + (2, 1), + (2, 1), + "CLIP", + "NHCWB16", + "NHCWB16", + ], + ], +) +def test_depthwise_conv2d_single(trial): + def _get_func( + ifm_shape, + channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + depthwise = make_ethosu_depthwise_conv2d( + ifm, + channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(depthwise), depthwise) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func(*trial) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_convolutional_args(stmt, remove_constants=True)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + ( + ifm_shape, + channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ) = trial + dilated_kernel_h = (kernel_shape[0] - 1) * dilation[0] + 1 + dilated_kernel_w = (kernel_shape[1] - 1) * dilation[1] + 1 + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] + ifm_stride_h = ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[2] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1 + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[3] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1 + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = channels if ofm_width > 1 else 1 + ofm_stride_h = channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((channels - 1) // 16 + 1) + + answer = [ + "int8", + ifm_shape[1], + ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + channels, + ifm_shape[1], + 0, + ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + 0, + 0, + 0, + 0, + 0.6, + 11, + ifm_layout, + ifm_stride_h, + ifm_stride_w, + ifm_stride_c, + "int8", + ofm_height, + ofm_width, + channels, + ofm_height, + 0, + ofm_width, + 0, + 0, + 0, + 0, + 0.26, + 15, + ofm_layout, + ofm_stride_h, + ofm_stride_w, + ofm_stride_c, + kernel_shape[1], + kernel_shape[0], + strides[1], + strides[0], + dilation[1], + dilation[0], + 13, + padding[0], + padding[1], + padding[0], + padding[1], + activation, + 15 if activation == "CLIP" else 0, + 105 if activation == "CLIP" else 0, + "NONE", + ] + assert data[0] == answer, data[0] diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index b07f3a5016fa..8240b392a1cf 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -497,6 +497,81 @@ def populate_ethosu_conv2d_calls(stmt): assert w_zero_point == ref["w_zero_point"] +# fmt: off +"""A ethosu_depthwise_conv2d tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuDepthwiseConv2D: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_depthwise_conv2d: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_1, [3, 3, 2, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_2, [3, 10], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [1, 6, 7, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, T.load("int8", ethosu_depthwise_conv2d_1.data, 0), 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, T.load("int8", placeholder_4.data, 0), 18, 13, T.load("uint8", placeholder_5.data, 0), 30, 0, 0, 0, 0, "CLIP", 15, 105, "NONE", dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +def test_translate_ethosu_depthwise_conv2d(): + def extract_ethosu_depthwise_conv2d_extern_call(mod): + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_depthwise_conv2d_calls = list() + + def populate_ethosu_depthwise_conv2d_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_depthwise_conv2d" + ): + ethosu_depthwise_conv2d_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_depthwise_conv2d_calls) + return ethosu_depthwise_conv2d_calls[0] + + depthwise_conv2d_call = extract_ethosu_depthwise_conv2d_extern_call(SingleEthosuDepthwiseConv2D) + npu_op, w_zero_point = tir_to_cs_translator.translate_ethosu_depthwise_conv2d( + depthwise_conv2d_call + ) + + assert npu_op.ifm.data_type == vapi.NpuDataType.INT8 + assert npu_op.ifm.shape == vapi.NpuShape3D(8, 8, 3) + assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).height_0 + assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).height_1 + assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).width_0 + assert npu_op.ifm.quantization == pytest.approx(vapi.NpuQuantization(0.6, 11)) + assert npu_op.ifm.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm.strides == vapi.NpuShape3D(24, 3, 1) + # Compare OFM + assert npu_op.ofm.data_type == vapi.NpuDataType.INT8 + assert npu_op.ofm.shape == vapi.NpuShape3D(6, 7, 3) + assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(6, 0, 8, [0, 0, 0, 0]).height_0 + assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(6, 0, 7, [0, 0, 0, 0]).height_1 + assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(6, 0, 7, [0, 0, 0, 0]).width_0 + assert npu_op.ofm.quantization == pytest.approx(vapi.NpuQuantization(0.26, 15)) + assert npu_op.ofm.layout == vapi.NpuLayout.NHWC + assert npu_op.ofm.strides == vapi.NpuShape3D(21, 3, 1) + # Compare kernel and padding + assert ( + npu_op.kernel.__dict__ + == vapi.NpuKernel(w=2, h=3, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1).__dict__ + ) + assert npu_op.padding == vapi.NpuPadding(top=0, left=0, bottom=0, right=0) + # Compare activation + assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + assert npu_op.activation.min == 0 + assert npu_op.activation.max == pytest.approx(23.4) + # Compare ifm upscaling + assert npu_op.ifm_upscale == vapi.NpuResamplingMode.NONE + # Compare weight quantization parameters + assert w_zero_point == 13 + + def test_translate_ethosu_copy(): def extract_ethosu_copy_extern_calls(mod): """This function will obtain all ethosu_conv2d diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py new file mode 100644 index 000000000000..47fddad773b2 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_type_inference.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +from tvm import relay +from tvm.relay.testing import run_opt_pass +from .infra import make_ethosu_conv2d +from .infra import make_ethosu_depthwise_conv2d + + +@pytest.mark.parametrize( + ["ifm_shape", "ifm_layout"], [((1, 56, 72, 55), "NHWC"), ((1, 56, 4, 72, 16), "NHCWB16")] +) +@pytest.mark.parametrize( + "ofm_shape,ofm_layout", [((1, 54, 38, 122), "NHWC"), ((1, 54, 8, 38, 16), "NHCWB16")] +) +def test_ethosu_conv2d_type_inference( + ifm_shape, + ifm_layout, + ofm_shape, + ofm_layout, +): + ifm_channels = 55 + ofm_channels = 122 + kernel_shape = (3, 2) + padding = (0, 1, 2, 3) + strides = (1, 2) + dilation = (2, 1) + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + conv2d = make_ethosu_conv2d( + ifm, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + f = relay.Function([ifm], conv2d) + f = run_opt_pass(f, relay.transform.InferType()) + assert tuple(f.body.checked_type.shape) == ofm_shape + + +@pytest.mark.parametrize( + "ifm_shape, ifm_layout", [((1, 46, 71, 55), "NHWC"), ((1, 46, 4, 71, 16), "NHCWB16")] +) +@pytest.mark.parametrize( + "ofm_shape, ofm_layout", [((1, 44, 37, 55), "NHWC"), ((1, 44, 4, 37, 16), "NHCWB16")] +) +def test_ethosu_depthwise_conv2d_type_inference( + ifm_shape, + ifm_layout, + ofm_shape, + ofm_layout, +): + channels = 55 + kernel_shape = (3, 2) + padding = (0, 1, 2, 3) + strides = (1, 2) + dilation = (2, 1) + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + depthwise_conv2d = make_ethosu_depthwise_conv2d( + ifm, + channels, + kernel_shape, + padding, + strides, + dilation, + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + f = relay.Function([ifm], depthwise_conv2d) + f = run_opt_pass(f, relay.transform.InferType()) + assert tuple(f.body.checked_type.shape) == ofm_shape + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 2e4687fb7985..6e57796d1cbf 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -471,7 +471,11 @@ def test_compile_tflite_module_with_external_codegen_ethosu( for name in mlf_package.getnames() if re.match(r"\./codegen/host/src/\D+\d+\.c", name) ] - assert len(c_source_files) == 17 + # The number of c_source_files depends on the number of fused subgraphs that + # get offloaded to the NPU, e.g. conv2d->depthwise_conv2d->conv2d gets offloaded + # as a single subgraph if both of these operators are supported by the NPU. + # Currently there are two source files for CPU execution and two offload graphs + assert len(c_source_files) == 4 @mock.patch("tvm.relay.build") From 5ad2f77403bed9a2bf356cc0d3d785ecc13e6c58 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 12 Oct 2021 01:22:10 +0900 Subject: [PATCH 04/84] [Relay] Gather op dynamic input support (#9240) * support gather op dynamic input * fix shape func and add test * remove constness check * fix shape func output rank * restore check Co-authored-by: masa --- include/tvm/topi/transform.h | 6 ++++-- python/tvm/relay/op/_transform.py | 20 ++++++++++++++++++++ src/relay/op/tensor/transform.cc | 6 ++++-- tests/python/relay/test_any.py | 22 ++++++++++++++++++++++ 4 files changed, 50 insertions(+), 4 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 8d1a49a4cc5f..3df9caf55d5c 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1233,8 +1233,10 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, } ICHECK_GE(axis, 0); ICHECK_LT(axis, ndim_d); - size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); - ICHECK_GE(indices_dim_i, 1); + if (indices->shape[axis].as()) { + size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); + ICHECK_GE(indices_dim_i, 1); + } ICHECK(indices->dtype.is_int()); Array out_shape; diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 0284d2483ce5..76c806905b18 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1174,3 +1174,23 @@ def gather_nd_shape_func(attrs, inputs, _): assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd" return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))] + + +@script +def _gather_shape(data_shape, indices_shape, axis): + out_shape = output_tensor((data_shape.shape[0],), "int64") + for i in range(data_shape.shape[0]): + if i != axis: + assert ( + data_shape[i] == indices_shape[i] + ), "data and indices size at non-gather axes must be the same" + out_shape[i] = indices_shape[i] + return out_shape + + +@_reg.register_shape_func("gather", False) +def gather_shape_func(attrs, inputs, _): + """ + Shape func for gather operator. + """ + return [_gather_shape(inputs[0], inputs[1], attrs.axis)] diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3781107eeee1..fa5b31a8abef 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3260,8 +3260,10 @@ bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, oshape.reserve(ndim_data); for (size_t i = 0; i < ndim_data; ++i) { if (i == static_cast(axis)) { - const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); - ICHECK_GE(*indice_shape_i, 1); + if (indices->shape[i].as()) { + const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); + ICHECK_GE(*indice_shape_i, 1); + } } else { ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i])); } diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index decddc1ef0a4..8788faf45866 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -2064,5 +2064,27 @@ def verify_scatter_nd(data_np, indices_np, updates_np, ref_res): verify_scatter_nd(data, indices, updates, out) +@tvm.testing.uses_gpu +def test_gather(): + def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis): + x = relay.var("x", relay.TensorType(data_shape, "float32")) + y = relay.var("y", relay.TensorType(indices_shape, "int32")) + z = relay.gather(x, axis, y) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + data_np = np.random.uniform(size=data_shape_np).astype("float32") + indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32") + + ref_res = tvm.topi.testing.gather_python(data_np, axis, indices_np) + check_result([data_np, indices_np], mod, [ref_res]) + + verify_gather((relay.Any(),), (relay.Any(),), (10,), (10,), 0) + verify_gather((2, 2), (2, relay.Any()), (2, 2), (2, 3), 1) + verify_gather((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3), 1) + verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0) + + if __name__ == "__main__": pytest.main([__file__]) From 95cde0c2c08b605b0ce335df67c07d4c376c4c7a Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 12 Oct 2021 02:46:15 +0900 Subject: [PATCH 05/84] [AlterLayout] Strided slice layout transform fix (disallow NCHW4c -> NCHW etc properly) (#9245) * prohibit propagating through packed to unpacked layout * add test --- src/relay/op/tensor/transform.cc | 16 +--- .../python/relay/test_pass_alter_op_layout.py | 74 +++++++++++++------ 2 files changed, 54 insertions(+), 36 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index fa5b31a8abef..90a0e3150573 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2599,24 +2599,19 @@ InferCorrectLayoutOutput StridedSliceInferCorrectLayout( params->strides = new_strides; layout = new_layout; } - } else { + } else if (old_layout_name.size() < + new_layout_name.size()) { // prohibit transforms such as NCHW4c -> NCHW if (params->axes) { auto axes = params->axes.value(); Array new_axes; - for (size_t i = 0; i < axes.size(); ++i) { auto old_idx = axes[i]; auto new_idx = new_layout.IndexOf(layout[old_idx]); new_axes.push_back(new_idx); const LayoutAxis& axis = layout[old_idx]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return out_default; - } - + ICHECK(axis.IsPrimal()); auto factor = new_layout.FactorOf(axis); - if (factor == -1) { new_begin.push_back(begin[i]); new_end.push_back(end[i]); @@ -2636,10 +2631,7 @@ InferCorrectLayoutOutput StridedSliceInferCorrectLayout( } else { for (size_t i = 0; i < begin.size(); i++) { const LayoutAxis& axis = layout[i]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return out_default; - } + ICHECK(axis.IsPrimal()); auto factor = new_layout.FactorOf(axis); if (factor == -1) { new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 3310b6b2ed69..19685b127d86 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1397,28 +1397,54 @@ def expected(): assert tvm.ir.structural_equal(a, b) +def test_conv2d_strided_slice_packed_to_unpacked(): + """We do not support propagating through packed to unpacked layout""" + x_shape = (1, 1, 1, 1, 4) + w_shape = (9, 1, 3, 3, 4, 4) + + def before(): + x = relay.var("x", shape=x_shape) + weight = relay.var("weight", shape=w_shape) + y = relay.nn.conv2d( + x, + weight, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW4c", + kernel_layout="OIHW4i4o", + ) + y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8]) + return relay.Function([x, weight], y) + + def expected(): + x = relay.var("x", shape=x_shape) + weight = relay.var("weight", shape=w_shape) + x_nchw = relay.layout_transform(x, src_layout="NCHW4c", dst_layout="NCHW") + weight_oihw = relay.layout_transform(weight, src_layout="OIHW4i4o", dst_layout="OIHW") + y = relay.nn.conv2d( + x_nchw, + weight_oihw, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NCHW4c") + y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8]) + return relay.Function([x, weight], y) + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs["data_layout"] = "NCHW" + new_attrs["kernel_layout"] = "OIHW" + return relay.nn.conv2d(data, weight, **new_attrs) + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + if __name__ == "__main__": - test_alter_op() - test_alter_return_none() - test_alter_layout() - test_alter_layout_dual_path() - test_alter_layout_lrn() - test_alter_layout_resnet() - test_alter_layout_broadcast_op() - test_alter_layout_broadcast_scalar_op() - test_alter_layout_scalar() - test_alter_layout_concatenate() - test_alter_layout_nchw_upsamping_op() - test_alter_layout_strided_slice() - test_alter_layout_depthwise_conv2d() - test_alter_layout_prelu() - test_alter_layout_pad() - test_alter_layout_pool() - test_alter_layout_sum() - test_alter_layout_nhwc_arm() - test_alter_layout_nhwc_int8_aarch64() - test_alter_op_with_global_var() - test_alter_op_dense() - test_alter_layout_strided_slice_axes_nhwc() - test_not_inplace_modify() - test_alter_op_dense_packed_data() + pytest.main([__file__]) From f08dca89e41359e6216c678fe45d268d8b1b60bf Mon Sep 17 00:00:00 2001 From: Alexey Gladyshev Date: Mon, 11 Oct 2021 21:57:27 +0300 Subject: [PATCH 06/84] [RPC] Fix Server connecting to RPC Tracker through a Proxy (#9210) --- python/tvm/rpc/proxy.py | 2 +- python/tvm/rpc/tracker.py | 5 +-- tests/python/unittest/test_runtime_rpc.py | 44 +++++++++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index e5ec73db51b9..c3b0056eb591 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -379,7 +379,7 @@ def _update_tracker(self, period_update=False): if need_update_info: keylist = "[" + ",".join(self._key_set) + "]" - cinfo = {"key": "server:proxy" + keylist} + cinfo = {"key": "server:proxy" + keylist, "addr": [None, self._listen_port]} base.sendjson(self._tracker_conn, [TrackerCode.UPDATE_INFO, cinfo]) assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS self._tracker_pending_puts = [] diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 74c1f7ac07aa..5a576a705e8a 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -337,9 +337,10 @@ def request(self, key, user, priority, callback): def close(self, conn): self._connections.remove(conn) if "key" in conn._info: - key = conn._info["key"].split(":")[1] # 'server:rasp3b' -> 'rasp3b' for value in conn.put_values: - self._scheduler_map[key].remove(value) + _, _, _, key = value + rpc_key = key.split(":")[0] + self._scheduler_map[rpc_key].remove(value) def stop(self): """Safely stop tracker.""" diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 22aea8d1fcea..6e1fc815d66d 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -29,6 +29,7 @@ from tvm import rpc from tvm.contrib import utils, cc from tvm.rpc.tracker import Tracker +from tvm.rpc.proxy import Proxy if __name__ == "__main__": @@ -538,3 +539,46 @@ def test_rpc_tracker_request(): proc2.join() server.terminate() tracker.terminate() + + +@tvm.testing.requires_rpc +def test_rpc_tracker_via_proxy(): + """ + tracker + / \ + Host -- Proxy -- RPC server + """ + + device_key = "test_device" + + tracker_server = Tracker(port=9000, port_end=9100) + proxy_server = Proxy( + host=tracker_server.host, + port=8888, + port_end=8988, + tracker_addr=(tracker_server.host, tracker_server.port), + ) + + server1 = rpc.Server( + host=proxy_server.host, + port=proxy_server.port, + key=device_key, + tracker_addr=(tracker_server.host, tracker_server.port), + is_proxy=True, + ) + server2 = rpc.Server( + host=proxy_server.host, + port=proxy_server.port, + key=device_key, + tracker_addr=(tracker_server.host, tracker_server.port), + is_proxy=True, + ) + + client = rpc.connect_tracker(tracker_server.host, tracker_server.port) + remote1 = client.request(device_key, session_timeout=30) # pylint: disable=unused-variable + remote2 = client.request(device_key, session_timeout=30) # pylint: disable=unused-variable + + server2.terminate() + server1.terminate() + proxy_server.terminate() + tracker_server.terminate() From 01744d1e8229f025e5f5ad481b9a71692d1a01c3 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 11 Oct 2021 17:13:06 -0700 Subject: [PATCH 07/84] Fix typo in error message in CMakeLists.txt (#9251) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c40b0c878905..24f0653b3a78 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -693,7 +693,7 @@ if(USE_CCACHE) # True for AUTO, ON, /path/to/ccache message(STATUS "Found the path to ccache, enabling ccache") set(PATH_TO_CCACHE ccache) else() - message(FATAL_ERROR "Cannot find ccache. Set USE_CCACHE mode to AUTO or OFF to build without ccache. USE_CCACHE=" "${USE_CCACHE") + message(FATAL_ERROR "Cannot find ccache. Set USE_CCACHE mode to AUTO or OFF to build without ccache. USE_CCACHE=" "${USE_CCACHE}") endif(CCACHE_FOUND) else() # /path/to/ccache set(PATH_TO_CCACHE USE_CCACHE) From 8725eb53d464bebd787ae5e02da6cf477db9cfc1 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Mon, 11 Oct 2021 17:13:22 -0700 Subject: [PATCH 08/84] add stage to log (#9249) --- src/te/operation/tensorize.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 447fc501d03b..9e2d3d0e725f 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -337,7 +337,8 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, } ICHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name << "'s declaration " - << " provided= " << lhs << ", intrin= " << rhs; + << " provided= " << lhs << ", intrin= " << rhs + << ", running this stage: " << stage; } } From a7cf3173084fd4a8e5b75c8d4af82a74a8799af4 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 12 Oct 2021 01:30:23 +0100 Subject: [PATCH 09/84] [TIR][USMP] Add a parallel to serial for loop converter pass (#8469) * [TIR][USMP] Add a parallel to serial for loop converter pass This is an optional pass to convert all parallel for loops in TIR to serial ones for different reasons such as executor does not support parallel launch of for loops (e.g., AoT) or allocating space for parallel for loops might not be desired. * Additionally adding FFI scaffolding for USMP Change-Id: Id5e8ccb90140d2d3ae113b20a3ca152a54497c45 * [TIR][USMP] Add a parallel to serial for loop converter pass * remove unused import Change-Id: I29d5fdec92120418596f9dba1d6630f65620a603 * [TIR][USMP] Add a parallel to serial for loop converter pass *moved the pass to tir namespace Change-Id: I74720ca2f566066b3a4f22f504d8f0f684c99dc2 * [TIR][USMP] Add a parallel to serial for loop converter pass * fixed docstring Change-Id: I73bb9867fe2ed6a86f65666493c5c6e3edf87b49 * [TIR][USMP] Add a parallel to serial for loop converter pass * fixed mypy lint error Change-Id: I226ef27d5536674fbe4b2d2c6ff47b8cb3b41431 --- include/tvm/tir/transform.h | 9 +++ python/tvm/tir/transform/transform.py | 11 +++ .../transforms/convert_for_loops_serial.cc | 75 +++++++++++++++++++ ..._tir_transform_convert_for_loops_serial.py | 62 +++++++++++++++ 4 files changed, 157 insertions(+) create mode 100644 src/tir/transforms/convert_for_loops_serial.cc create mode 100644 tests/python/unittest/test_tir_transform_convert_for_loops_serial.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e94b966bc0fc..017078bd7bf7 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -463,6 +463,15 @@ TVM_DLL Pass UnifyThreadBinding(); */ TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); +/*! + * \brief This pass is post-scheduling pass to convert all + * Parallel For loops to Serial ones. This is run + * to attain lesser memory and/or executor/backend + * does not support parallel launch of For loops. + * \return The pass. + */ +TVM_DLL Pass ConvertForLoopsToSerial(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f072f6b38a43..1abba77a801f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -715,3 +715,14 @@ def MergeDynamicSharedMemoryAllocations(): The result pass """ return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore + + +def ConvertForLoopsToSerial(): + """Convert Parallel For Loops to Serial For Loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ConvertForLoopsToSerial() # type: ignore diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc new file mode 100644 index 000000000000..d01ae8a45113 --- /dev/null +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/transforms/convert_for_loops_serial.cc + * \brief Convert all for loops to serial for lesser memory consumption + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class ForLoopSerialConverter : public StmtExprMutator { + public: + ForLoopSerialConverter() = default; + Stmt operator()(const PrimFunc& func); + + private: + Stmt VisitStmt_(const ForNode* op) override; +}; + +Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) { + if (op->kind == ForKind::kParallel) { + return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding, + op->annotations, op->span); + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt ForLoopSerialConverter::operator()(const PrimFunc& func) { + return this->VisitStmt(func->body); +} + +PrimFunc ConvertForLoopsToSerial(PrimFunc func) { + PrimFuncNode* fptr = func.CopyOnWrite(); + fptr->body = ForLoopSerialConverter()(func); + return func; +} + +namespace transform { + +Pass ConvertForLoopsToSerial() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ConvertForLoopsToSerial(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") + .set_body_typed(ConvertForLoopsToSerial); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py new file mode 100644 index 000000000000..272e0d45410f --- /dev/null +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import tir, script +from tvm.script import ty +from tvm.tir import stmt_functor + +# fmt: off +@tvm.script.tir +def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global") + for i0_i1_fused_3 in tir.parallel(0, 28): + for i2_3, i3_3 in tir.grid(28, 192): + tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784): + for ax3_2 in tir.serial(0, 16): + Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global") + tir.store(Conv2dOutput_3, 0, 0, True) + for rc_3 in tir.serial(0, 192): + tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) + tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) +# fmt: on + + +def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(): + primfunc = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2 + mod = tvm.IRModule.from_expr(primfunc) + mod = tvm.tir.transform.ConvertForLoopsToSerial()(mod) + + def verify_serial_loops(stmt): + if isinstance(stmt, tvm.tir.For): + assert stmt.kind == tvm.tir.ForKind.SERIAL + + for _, primfunc in mod.functions.items(): + stmt_functor.post_order_visit(primfunc.body, verify_serial_loops) + + +if __name__ == "__main__": + pytest.main([__file__]) From d1967f2b1a5dd8b9fa6f6aa44b2c557400dee61f Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 12 Oct 2021 18:17:42 +0900 Subject: [PATCH 10/84] [Relay] Improve reduction op layout propagation for packed input (#9253) * wip * fixed packed dim size logic * fixed test * formatting * fix compile warning --- src/relay/op/tensor/reduce.cc | 42 +++++++++++++------ .../python/relay/test_pass_alter_op_layout.py | 22 +++++----- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index c9f14c91c7b1..5001925b7570 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -149,23 +149,41 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs, tvm::Array new_r_axes; std::string inferred_in_string = ""; std::string inferred_out_string = ""; - int axis_index = 0; - for (auto iter_var : layout->axes) { - const auto& layout_axis = LayoutAxis::Get(iter_var); + auto push_new_axis = [&](const std::string& layout_dim, int axis) { + if ((old_r_dims.count(layout_dim) && !params->exclude) || + (!old_r_dims.count(layout_dim) && params->exclude)) { + new_r_axes.push_back(tvm::Integer(axis)); + return true; + } + return false; + }; + for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) { + const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]); const std::string& layout_dim = layout_axis.name(); - // Collect only the primal axis. if (layout_axis.IsPrimal()) { - if (old_r_dims.count(layout_dim) && !params->exclude) { - new_r_axes.push_back(tvm::Integer(axis_index)); - } - if (!old_r_dims.count(layout_dim) && params->exclude) { - new_r_axes.push_back(tvm::Integer(axis_index)); - } + push_new_axis(layout_dim, axis_index); + inferred_in_string += layout_dim; if (!old_r_dims.count(layout_dim) || params->keepdims) { inferred_out_string += layout_dim; } - inferred_in_string += layout_dim; - axis_index++; + } else { + // For example, if the original layout is NCHW, the new layout is NCHW8c, and the original + // reduce axes is [1], the new reduce axes become [1, 4]. + auto primal_dim = layout_axis.ToPrimal().name(); + auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim; + inferred_in_string += packed_dim; + if (push_new_axis(primal_dim, axis_index)) { + if (params->exclude) { + // The primal axis is not reduced, so keep the input packed dim. + inferred_out_string += packed_dim; + } else { + // If the primal axis is part of reduce axes in the original layout, the inner dim + // becomes 1 after reduction. + inferred_out_string += "1" + layout_dim; + } + } else { + inferred_out_string += packed_dim; + } } } diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 19685b127d86..ab36f79c6ea7 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -507,12 +507,12 @@ def expected(): bias = relay.layout_transform(bias, src_layout="NHWC", dst_layout="NCHW") bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c") add = relay.add(y, bias) - y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW") - mean = relay.mean(y, axis=1, exclude=True) - var = relay.variance(y, axis=1, exclude=True) + mean = relay.mean(add, axis=[1, 4], exclude=True) + var = relay.variance(add, axis=[1, 4], exclude=True) denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05)) gamma = relay.var("gamma", shape=(16,)) - denom = denom * gamma + denom_c16c = denom * relay.layout_transform(gamma, src_layout="C", dst_layout="C16c") + denom = relay.layout_transform(denom_c16c, src_layout="C16c", dst_layout="C") denom_expand1 = relay.expand_dims(denom, axis=1, num_newaxis=2) denom_expand2 = relay.expand_dims(denom_expand1, axis=0) denom_nchwc16 = relay.layout_transform( @@ -520,7 +520,10 @@ def expected(): ) out = add * denom_nchwc16 beta = relay.var("beta", shape=(16,)) - numerator = (-mean) * denom + beta + numerator_c16c = (-mean) * denom_c16c + relay.layout_transform( + beta, src_layout="C", dst_layout="C16c" + ) + numerator = relay.layout_transform(numerator_c16c, src_layout="C16c", dst_layout="C") numerator_expand1 = relay.expand_dims(numerator, axis=1, num_newaxis=2) numerator_expand2 = relay.expand_dims(numerator_expand1, axis=0) numerator_nchwc16 = relay.layout_transform( @@ -1096,8 +1099,8 @@ def expected_nchw(): y = relay.nn.conv2d( y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c" ) - ret = relay.layout_transform(y, "NCHW16c", "NCHW") - ret = relay.sum(ret, axis=[1], keepdims=True) + ret = relay.sum(y, axis=[1, 4], keepdims=True) + ret = relay.layout_transform(ret, "NCHW1c", "NCHW") y = relay.Function(analysis.free_vars(ret), ret) return y @@ -1126,9 +1129,8 @@ def expected_nhwc(): y = relay.nn.conv2d( y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c" ) - ret = relay.layout_transform(y, "NCHW16c", "NCHW") - ret = relay.sum(ret, axis=[1], keepdims=True) - ret = relay.layout_transform(ret, "NCHW", "NHWC") + ret = relay.sum(y, axis=[1, 4], keepdims=True) + ret = relay.layout_transform(ret, "NCHW1c", "NHWC") y = relay.Function(analysis.free_vars(ret), ret) return y From 0d10973e1148698bf93f34a07557dec4bcff90ed Mon Sep 17 00:00:00 2001 From: lhutton1 <35535092+lhutton1@users.noreply.github.com> Date: Tue, 12 Oct 2021 16:08:21 +0100 Subject: [PATCH 11/84] [microNPU] Enforce bias when pattern matching conv2d (#9244) Currently a conv2d pattern is matched when no bias is present. However, legalization expects a bias to be present, therefore causing an error when this is not the case. For now, enforce a bias when offloading conv2d to the NPU. Change-Id: I7f74b0f2c151f51ddc66ee1c5ebb77534238909b --- python/tvm/relay/op/contrib/ethosu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 4369376b5689..ca417942840d 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -310,7 +310,7 @@ def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: ).has_attr({"kernel_layout": "HWIO"}) bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) req = is_op("qnn.requantize")( - qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant() + bias_add, is_constant(), is_constant(), is_constant(), is_constant() ) clip_or_req = req.optional(is_op("clip")) return clip_or_req From f4922bca287514f7579a49820a1d6b1587d7e6be Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Tue, 12 Oct 2021 16:17:55 +0100 Subject: [PATCH 12/84] Fix USMP parallel to serial loop transform test (#9254) Caused by https://github.com/apache/tvm/pull/8469 being stale on merge when https://github.com/apache/tvm/pull/9115 had changed the namespace for `tvm.script`. --- ..._tir_transform_convert_for_loops_serial.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py index 272e0d45410f..a91fa2591e00 100644 --- a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -17,31 +17,31 @@ import pytest import tvm -from tvm import tir, script -from tvm.script import ty + +from tvm.script import tir as T from tvm.tir import stmt_functor # fmt: off -@tvm.script.tir -def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None: +@T.prim_func +def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) - placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global") - for i0_i1_fused_3 in tir.parallel(0, 28): - for i2_3, i3_3 in tir.grid(28, 192): - tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) - for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784): - for ax3_2 in tir.serial(0, 16): - Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global") - tir.store(Conv2dOutput_3, 0, 0, True) - for rc_3 in tir.serial(0, 192): - tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) - tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global") + for i0_i1_fused_3 in T.parallel(0, 28): + for i2_3, i3_3 in T.grid(28, 192): + T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784): + for ax3_2 in T.serial(0, 16): + Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global") + T.store(Conv2dOutput_3, 0, 0, True) + for rc_3 in T.serial(0, 192): + T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) + T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) # fmt: on From 9f27be60ad1c84706ff2bdbe438626465970681a Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Tue, 12 Oct 2021 16:26:59 +0100 Subject: [PATCH 13/84] [TVMC] Split common tvmc test file into more specific files (#9206) The `test_tvmc_common.py` file was becoming a bit of a mixed bag of tests and as we now want to extend the `Target` processing logic it made sense to split each out into its own file to make it clearer what each does. `test_common.py` has also been renamed before we start using it for all the tests instead. --- tests/python/driver/tvmc/test_frontends.py | 128 +++++- tests/python/driver/tvmc/test_pass_config.py | 73 ++++ .../{test_common.py => test_pass_list.py} | 3 +- tests/python/driver/tvmc/test_shape_parser.py | 96 ++++ tests/python/driver/tvmc/test_target.py | 143 ++++++ tests/python/driver/tvmc/test_tracker.py | 49 +++ tests/python/driver/tvmc/test_tvmc_common.py | 413 ------------------ 7 files changed, 489 insertions(+), 416 deletions(-) create mode 100644 tests/python/driver/tvmc/test_pass_config.py rename tests/python/driver/tvmc/{test_common.py => test_pass_list.py} (97%) create mode 100644 tests/python/driver/tvmc/test_shape_parser.py create mode 100644 tests/python/driver/tvmc/test_target.py create mode 100644 tests/python/driver/tvmc/test_tracker.py delete mode 100644 tests/python/driver/tvmc/test_tvmc_common.py diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index 569c42020817..4d2fb56c5d4e 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os -import tarfile import pytest +import tvm from tvm.ir.module import IRModule from tvm.driver import tvmc @@ -229,3 +228,128 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): model_format="pytorch", shape_dict={"input": [1, 3, 224, 224]}, ) + + +def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + before = tvmc_model.mod + + expected_layout = "NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NCHW" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found" + + +def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip("onnx") + + tvmc_model = tvmc.frontends.load_model(onnx_resnet50) + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" + + +def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50): + # some CI environments wont offer Paddle, so skip in case it is not present + pytest.importorskip("paddle") + + tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle") + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" + + +def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" + + +def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip("onnx") + + tvmc_model = tvmc.frontends.load_model(onnx_resnet50) + before = tvmc_model.mod + + expected_layout = "NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NCHW" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" diff --git a/tests/python/driver/tvmc/test_pass_config.py b/tests/python/driver/tvmc/test_pass_config.py new file mode 100644 index 000000000000..d8ffd7d4d521 --- /dev/null +++ b/tests/python/driver/tvmc/test_pass_config.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm.contrib.target.vitis_ai import vitis_ai_available +from tvm.driver import tvmc + +from tvm.driver.tvmc.common import TVMCException + + +def test_config_invalid_format(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) + + +def test_config_missing_from_tvm(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) + + +def test_config_unsupported_tvmc_config(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) + + +def test_config_empty(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs([""]) + + +def test_config_valid_config_bool(): + configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) + + assert len(configs) == 1 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == True + + +@pytest.mark.skipif( + not vitis_ai_available(), + reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", +) +def test_config_valid_multiple_configs(): + configs = tvmc.common.parse_configs( + [ + "relay.backend.use_auto_scheduler=false", + "tir.detect_global_barrier=10", + "relay.ext.vitis_ai.options.build_dir=mystring", + ] + ) + + assert len(configs) == 3 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == False + assert "tir.detect_global_barrier" in configs.keys() + assert configs["tir.detect_global_barrier"] == 10 + assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() + assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring" diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_pass_list.py similarity index 97% rename from tests/python/driver/tvmc/test_common.py rename to tests/python/driver/tvmc/test_pass_list.py index 5cac6a1378a5..de50b04f415a 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_pass_list.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import argparse import pytest from tvm.driver import tvmc -def test_common_parse_pass_list_str(): +def test_parse_pass_list_str(): assert [""] == tvmc.common.parse_pass_list_str("") assert ["FoldScaleAxis", "FuseOps"] == tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps") diff --git a/tests/python/driver/tvmc/test_shape_parser.py b/tests/python/driver/tvmc/test_shape_parser.py new file mode 100644 index 000000000000..c021078630ed --- /dev/null +++ b/tests/python/driver/tvmc/test_shape_parser.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse + +import pytest + +from tvm.driver import tvmc + + +def test_shape_parser(): + # Check that a valid input is parsed correctly + shape_string = "input:[10,10,10]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10]} + + +def test_alternate_syntax(): + shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} + + +@pytest.mark.parametrize( + "shape_string", + [ + "input:[10,10,10] input2:[20,20,20,20]", + "input: [10, 10, 10] input2: [20, 20, 20, 20]", + "input:[10,10,10],input2:[20,20,20,20]", + ], +) +def test_alternate_syntaxes(shape_string): + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} + + +def test_negative_dimensions(): + # Check that negative dimensions parse to Any correctly. + shape_string = "input:[-1,3,224,224]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + # Convert to strings to allow comparison with Any. + assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" + + +def test_multiple_valid_gpu_inputs(): + # Check that multiple valid gpu inputs are parsed correctly. + shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" + assert str(shape_dict) == expected + + +def test_invalid_pattern(): + shape_string = "input:[a,10]" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +def test_invalid_separators(): + shape_string = "input:5,10 input2:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +def test_invalid_colon(): + shape_string = "gpu_0/data_0:5,10 :test:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +@pytest.mark.parametrize( + "shape_string", + [ + "gpu_0/data_0:5,10 /:10,10", + "gpu_0/data_0:5,10 data/:10,10", + "gpu_0/data_0:5,10 /data:10,10", + "gpu_0/invalid/data_0:5,10 data_1:10,10", + ], +) +def test_invalid_slashes(shape_string): + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) diff --git a/tests/python/driver/tvmc/test_target.py b/tests/python/driver/tvmc/test_target.py new file mode 100644 index 000000000000..afb099f3add6 --- /dev/null +++ b/tests/python/driver/tvmc/test_target.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm.driver import tvmc + +from tvm.driver.tvmc.common import TVMCException + + +def test_target_from_cli__error_duplicate(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("llvm, llvm") + + +def test_target_invalid_more_than_two_tvm_targets(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("cuda, opencl, llvm") + + +def test_target_from_cli__error_target_not_found(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("invalidtarget") + + +def test_target_from_cli__error_no_tvm_target(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("ethos-n77") + + +def test_target_two_tvm_targets(): + tvm_target, extra_targets = tvmc.common.target_from_cli( + "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu" + ) + + assert "opencl" in str(tvm_target) + assert "llvm" in str(tvm_target.host) + + # No extra targets + assert 0 == len(extra_targets) + + +def test_tokenize_target_with_opts(): + tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") + expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", "-opt2=value2"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_plus_sign(): + tokens = tvmc.common.tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v") + expected_tokens = ["foo", "-opt1=+value1", "--flag", ",", "bar", "-opt2=test,+v"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas(): + tokens = tvmc.common.tokenize_target("foo -opt1=v,a,l,u,e,1 --flag") + expected_tokens = ["foo", "-opt1=v,a,l,u,e,1", "--flag"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas_and_single_quotes(): + tokens = tvmc.common.tokenize_target("foo -opt1='v, a, l, u, e', bar") + expected_tokens = ["foo", "-opt1='v, a, l, u, e'", ",", "bar"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas_and_double_quotes(): + tokens = tvmc.common.tokenize_target('foo -opt1="v, a, l, u, e", bar') + expected_tokens = ["foo", '-opt1="v, a, l, u, e"', ",", "bar"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_dashes(): + tokens = tvmc.common.tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz") + expected_tokens = ["foo-bar1", "-opt-1=t-e-s-t", ",", "baz"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_parse_single_target_with_opts(): + targets = tvmc.common.parse_target("llvm -device=arm_cpu --system-lib") + + assert len(targets) == 1 + assert "device" in targets[0]["opts"] + assert "system-lib" in targets[0]["opts"] + + +def test_parse_multiple_target(): + targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu --system-lib") + + assert len(targets) == 2 + assert "compute-library" == targets[0]["name"] + assert "llvm" == targets[1]["name"] + + +def test_parse_multiple_target_with_opts(): + targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") + + assert len(targets) == 2 + assert "ethos-n77" == targets[0]["name"] + assert "myopt" in targets[0]["opts"] + assert "value" == targets[0]["opts"]["myopt"] + assert "llvm" == targets[1]["name"] + + +def test_parse_quotes_and_separators_on_options(): + targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") + targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") + targets_double_quote = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') + + assert len(targets_no_quote) == 1 + assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"] + + assert len(targets_single_quote) == 1 + assert "+v1.0x,+value" == targets_single_quote[0]["opts"]["option1"] + + assert len(targets_double_quote) == 1 + assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] diff --git a/tests/python/driver/tvmc/test_tracker.py b/tests/python/driver/tvmc/test_tracker.py new file mode 100644 index 000000000000..2ca0fae8f45e --- /dev/null +++ b/tests/python/driver/tvmc/test_tracker.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tvm.driver import tvmc + + +def test_tracker_host_port_from_cli__hostname_port(): + input_str = "1.2.3.4:9090" + expected_host = "1.2.3.4" + expected_port = 9090 + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert expected_host == actual_host + assert expected_port == actual_port + + +def test_tracker_host_port_from_cli__hostname_port__empty(): + input_str = "" + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert actual_host is None + assert actual_port is None + + +def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): + input_str = "1.2.3.4" + expected_host = "1.2.3.4" + expected_port = 9090 + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert expected_host == actual_host + assert expected_port == actual_port diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py deleted file mode 100644 index bdfdb48ce6a0..000000000000 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ /dev/null @@ -1,413 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import argparse - -import pytest - -import tvm -from tvm.contrib.target.vitis_ai import vitis_ai_available -from tvm.driver import tvmc - -from tvm.driver.tvmc.common import TVMCException - - -def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): - # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip("tflite") - - tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) - before = tvmc_model.mod - - expected_layout = "NCHW" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NHWC" - and node.attrs.dst_layout == "NCHW" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found" - - -def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): - # some CI environments wont offer ONNX, so skip in case it is not present - pytest.importorskip("onnx") - - tvmc_model = tvmc.frontends.load_model(onnx_resnet50) - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" - - -def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50): - # some CI environments wont offer Paddle, so skip in case it is not present - pytest.importorskip("paddle") - - tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle") - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" - - -def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): - # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip("tflite") - - tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NHWC" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" - - -def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): - # some CI environments wont offer ONNX, so skip in case it is not present - pytest.importorskip("onnx") - - tvmc_model = tvmc.frontends.load_model(onnx_resnet50) - before = tvmc_model.mod - - expected_layout = "NCHW" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NCHW" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" - - -def test_tracker_host_port_from_cli__hostname_port(): - input_str = "1.2.3.4:9090" - expected_host = "1.2.3.4" - expected_port = 9090 - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert expected_host == actual_host - assert expected_port == actual_port - - -def test_tracker_host_port_from_cli__hostname_port__empty(): - input_str = "" - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert actual_host is None - assert actual_port is None - - -def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): - input_str = "1.2.3.4" - expected_host = "1.2.3.4" - expected_port = 9090 - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert expected_host == actual_host - assert expected_port == actual_port - - -def test_shape_parser(): - # Check that a valid input is parsed correctly - shape_string = "input:[10,10,10]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10]} - # Check that multiple valid input shapes are parse correctly - shape_string = "input:[10,10,10] input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that multiple valid input shapes with colons are parse correctly - shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that alternate syntax parses correctly - shape_string = "input: [10, 10, 10] input2: [20, 20, 20, 20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - shape_string = "input:[10,10,10],input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that negative dimensions parse to Any correctly. - shape_string = "input:[-1,3,224,224]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - # Convert to strings to allow comparison with Any. - assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" - # Check that multiple valid gpu inputs are parsed correctly. - shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" - assert str(shape_dict) == expected - - # Check that invalid pattern raises expected error. - shape_string = "input:[a,10]" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with invalid separators raises error. - shape_string = "input:5,10 input2:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 /:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid colon raises error. - shape_string = "gpu_0/data_0:5,10 :test:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 data/:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 /data:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with invalid slashes raises error. - shape_string = "gpu_0/invalid/data_0:5,10 data_1:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - - -def test_target_from_cli__error_duplicate(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("llvm, llvm") - - -def test_target_invalid_more_than_two_tvm_targets(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("cuda, opencl, llvm") - - -def test_target_from_cli__error_target_not_found(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("invalidtarget") - - -def test_target_from_cli__error_no_tvm_target(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("ethos-n77") - - -def test_target_two_tvm_targets(): - tvm_target, extra_targets = tvmc.common.target_from_cli( - "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu" - ) - - assert "opencl" in str(tvm_target) - assert "llvm" in str(tvm_target.host) - - # No extra targets - assert 0 == len(extra_targets) - - -def test_tokenize_target_with_opts(): - tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") - expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", "-opt2=value2"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_plus_sign(): - tokens = tvmc.common.tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v") - expected_tokens = ["foo", "-opt1=+value1", "--flag", ",", "bar", "-opt2=test,+v"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas(): - tokens = tvmc.common.tokenize_target("foo -opt1=v,a,l,u,e,1 --flag") - expected_tokens = ["foo", "-opt1=v,a,l,u,e,1", "--flag"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas_and_single_quotes(): - tokens = tvmc.common.tokenize_target("foo -opt1='v, a, l, u, e', bar") - expected_tokens = ["foo", "-opt1='v, a, l, u, e'", ",", "bar"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas_and_double_quotes(): - tokens = tvmc.common.tokenize_target('foo -opt1="v, a, l, u, e", bar') - expected_tokens = ["foo", '-opt1="v, a, l, u, e"', ",", "bar"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_dashes(): - tokens = tvmc.common.tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz") - expected_tokens = ["foo-bar1", "-opt-1=t-e-s-t", ",", "baz"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_parse_single_target_with_opts(): - targets = tvmc.common.parse_target("llvm -device=arm_cpu --system-lib") - - assert len(targets) == 1 - assert "device" in targets[0]["opts"] - assert "system-lib" in targets[0]["opts"] - - -def test_parse_multiple_target(): - targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu --system-lib") - - assert len(targets) == 2 - assert "compute-library" == targets[0]["name"] - assert "llvm" == targets[1]["name"] - - -def test_parse_multiple_target_with_opts(): - targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") - - assert len(targets) == 2 - assert "ethos-n77" == targets[0]["name"] - assert "myopt" in targets[0]["opts"] - assert "value" == targets[0]["opts"]["myopt"] - assert "llvm" == targets[1]["name"] - - -def test_parse_quotes_and_separators_on_options(): - targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") - targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") - targets_double_quote = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') - - assert len(targets_no_quote) == 1 - assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"] - - assert len(targets_single_quote) == 1 - assert "+v1.0x,+value" == targets_single_quote[0]["opts"]["option1"] - - assert len(targets_double_quote) == 1 - assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] - - -def test_config_invalid_format(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) - - -def test_config_missing_from_tvm(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) - - -def test_config_unsupported_tvmc_config(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) - - -def test_config_empty(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs([""]) - - -def test_config_valid_config_bool(): - configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) - - assert len(configs) == 1 - assert "relay.backend.use_auto_scheduler" in configs.keys() - assert configs["relay.backend.use_auto_scheduler"] == True - - -@pytest.mark.skipif( - not vitis_ai_available(), - reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", -) -def test_config_valid_multiple_configs(): - configs = tvmc.common.parse_configs( - [ - "relay.backend.use_auto_scheduler=false", - "tir.detect_global_barrier=10", - "relay.ext.vitis_ai.options.build_dir=mystring", - ] - ) - - assert len(configs) == 3 - assert "relay.backend.use_auto_scheduler" in configs.keys() - assert configs["relay.backend.use_auto_scheduler"] == False - assert "tir.detect_global_barrier" in configs.keys() - assert configs["tir.detect_global_barrier"] == 10 - assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() - assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring" From 4f6b478336042583d369cf7489baa671cde25f3f Mon Sep 17 00:00:00 2001 From: mbaret <55580676+mbaret@users.noreply.github.com> Date: Tue, 12 Oct 2021 21:50:34 +0100 Subject: [PATCH 14/84] Address review comments on Arm(R) Ethos(TM)-U PR 3/6 (#9159) * Address review comments on Arm(R) Ethos(TM)-U PR 3/6 Change-Id: I22961885a503be31f6a72622ae0b5f874cc6f463 * Fix rebasing error Change-Id: I3e2fde786096ea331fcb366080fa779ec4ea4a5d * Fix more rebasing problems Change-Id: I1026e3ccee33a3fdec9ebbf6456bae244ad4f1d5 --- .../backend/contrib/ethosu/tir/compiler.py | 20 +- .../backend/contrib/ethosu/tir/convolution.py | 2 +- .../relay/backend/contrib/ethosu/tir/dma.py | 2 +- .../backend/contrib/ethosu/tir/passes.py | 6 +- .../backend/contrib/ethosu/tir/scheduler.py | 36 +-- .../backend/contrib/ethosu/tir/transform.py | 2 +- .../relay/backend/contrib/ethosu/tir/utils.py | 2 +- .../contrib/ethosu/tir_to_cs_translator.py | 164 ++++++------ .../relay/backend/contrib/ethosu/vela_api.py | 50 ++-- .../backend/contrib/ethosu/to_te_graph.cc | 234 ------------------ src/relay/backend/te_compiler_cache.cc | 40 +-- .../contrib/test_ethosu/test_attr_passing.py | 8 +- .../test_ethosu/test_encode_constants.py | 16 +- .../contrib/test_ethosu/test_scheduler.py | 14 +- .../contrib/test_ethosu/test_vela_api.py | 5 +- 15 files changed, 197 insertions(+), 404 deletions(-) delete mode 100644 src/relay/backend/contrib/ethosu/to_te_graph.cc diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index c59a386fefbb..3283e0515c72 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler""" +"""The integration of the Arm(R) Ethos(TM)-U NPU TIR compiler.""" import tvm from tvm import relay from tvm.relay.expr_functor import ExprMutator @@ -29,7 +29,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target. The resulting TIR module will contain a single function - that comprises of a sequence of tir.extern_calls to NPU + that consists of a sequence of tir.extern_calls to NPU operations. Parameters @@ -96,20 +96,20 @@ def lower_ethosu(sch, args, const_dict, name="main"): def lower_to_te(prim_func): - """Lower a Relay primitive function to a Tensor Expression graph. + """Lower a Relay primitive function to a Tensor Expression in an unscheduled CachedFunc. Parameters ---------- prim_func : tvm.relay.Function - The Relay function to lowerethosu_runtime([]). + The Relay function to lower. Returns ------- - out : TEGraph - The lowered Tensor Expression graph. + out : CachedFunc + The lowered Tensor Expression as part of a CachedFunc. """ - f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE") + f = tvm._ffi.get_global_func("relay.backend.LowerToTE") return f(prim_func) @@ -193,7 +193,7 @@ def lower_to_tir(func, cascader=None): func, consts = extract_constants(func) mod = tvm.IRModule.from_expr(func) func = relay.transform.InferType()(mod)["main"] - te_graph = lower_to_te(func) - s = schedule(te_graph, consts, cascader) - mod, consts = lower_ethosu(s, te_graph, consts) + cached_func = lower_to_te(func) + s = schedule(cached_func, consts, cascader) + mod, consts = lower_ethosu(s, cached_func, consts) return mod, consts diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 33fbdcd2b24f..fd7fa293ccfb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the convolution operators in TIR.""" +"""Extract parameters from the convolution operators in TIR.""" import tvm from ..vela_api import SCALE_BIAS_LENGTH from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index ecd402d63309..a116e51c5b7c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the DMA operators in TIR.""" +"""Extract parameters from the DMA operators in TIR.""" import tvm from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs from .spec import SerialFeatureMap, SerialPadding diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 8bb410e986c7..761c8aad7bb1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler""" +"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler.""" import numpy as np # type: ignore import tvm @@ -301,7 +301,7 @@ def EncodeConstants(const_dict): pointer_to_buffer = {} rewrite_buffer = {} rewrite_pointer = {} - accel_type = vela_api.get_target_accel_type() # type: ignore + accel_config = vela_api.get_accelerator_config() def _align_scale_bias(tir_extern_call, bias): """Align the scale_bias to 16 bytes.""" @@ -316,7 +316,7 @@ def _align_scale_bias(tir_extern_call, bias): def _encode_weights(tir_extern_call, weights): """Encode the weights for a TIR extern call.""" - value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_type) + value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config) value = np.frombuffer(value_bytes, dtype="uint8") return value diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 5d9027bf2078..7f892d0c602a 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -15,17 +15,17 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Different schedulers for Arm(R) Ethos(TM)-U NPU""" +"""Scheduling for Arm(R) Ethos(TM)-U NPU.""" import tvm -def schedule(te_graph, const_dict, cascader=None): - """Schedule a TE graph for NPU compilation. +def schedule(cached_func, const_dict, cascader=None): + """Schedule a CachedFunc for NPU compilation. Parameters ---------- - te_graph - The TE graph to schedule. + cached_func : CachedFunc + The CachedFunc to schedule. const_dict : dict of int to numpy.ndarray The constant dictionary. cascader : callable, optional @@ -38,10 +38,10 @@ def schedule(te_graph, const_dict, cascader=None): The completed schedule for the graph. """ - s = tvm.te.create_schedule([t.op for t in te_graph.outputs]) + s = tvm.te.create_schedule([t.op for t in cached_func.outputs]) if cascader: - cascader(te_graph, const_dict, s) - inline_no_ops(te_graph, s) + cascader(cached_func, const_dict, s) + inline_no_ops(cached_func, s) schedule_pragmas(s) schedule_cache_reads(s) return s @@ -96,7 +96,7 @@ def total_cascader(stripe_size): """ - def _cascader(te_graph, const_dict, sch): + def _cascader(cached_func, const_dict, sch): scheduled = set() def _visit(tensor, stage, ax): @@ -106,8 +106,8 @@ def _visit(tensor, stage, ax): for input_tensor in tensor.op.input_tensors: _visit(input_tensor, stage, ax) - assert len(te_graph.outputs) == 1 - out = te_graph.outputs[0] + assert len(cached_func.outputs) == 1 + out = cached_func.outputs[0] oi, _ = tile_nd(sch, out, stripe_size) for ax in oi: sch[out].unroll(ax) @@ -126,14 +126,14 @@ def copy_constants(): The planning function. """ - def _planner(te_graph, const_dict, sch): + def _planner(cached_func, const_dict, sch): planned = set() # type: ignore def _visit(tensor, reader): if tensor is not planned: planned.add(tensor) if isinstance(tensor.op, tvm.te.PlaceholderOp): - index = list(te_graph.inputs).index(tensor) + index = list(cached_func.inputs).index(tensor) if index in const_dict: sch.cache_read(tensor, "global", [reader]) @@ -141,7 +141,7 @@ def _visit(tensor, reader): for input_tensor in tensor.op.input_tensors: _visit(input_tensor, tensor) - for output_tensor in te_graph.outputs: + for output_tensor in cached_func.outputs: _visit(output_tensor, None) return _planner @@ -216,7 +216,7 @@ def _detect_cache_read(stage): stage.pragma(fax, "op", "ethosu_copy") -def inline_no_ops(te_graph, sch): +def inline_no_ops(cached_func, sch): """Inline 'no-ops' - operations that in principle do nothing. Modifies the schedule in-place. For now we inline reshape and @@ -224,8 +224,8 @@ def inline_no_ops(te_graph, sch): Parameters ---------- - te_graph - The TE graph. + cached_func : CachedFunc + The cached func. sch : tvm.te.Schedule The schedule. @@ -241,7 +241,7 @@ def _visit(tensor): for input_tensor in tensor.op.input_tensors: _visit(input_tensor) - for out in te_graph.outputs: + for out in cached_func.outputs: _visit(out) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index 0403ce2c7e8f..f50975c83838 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Extract information from the transform operators in TIR.""" +"""Extract parameters from the transform operators in TIR.""" import tvm from .spec import SerialCopy from .utils import get_base_address, get_op_attrs diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index ccfc2dfbfc48..de1c0ab19f6e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name -"""Helper utility functions used by the TIR compiler""" +"""Helper utility functions used by the NPU TIR compiler""" import tvm from tvm import arith diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 408eab6427ca..bcae01a10214 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -18,7 +18,7 @@ the Relay to TIR compilation process, to Vela API calls to generate command stream. """ -from typing import NamedTuple +from typing import Dict, NamedTuple, Tuple, Union from enum import auto from enum import Enum import numpy as np # type: ignore @@ -32,7 +32,7 @@ class BufferType(Enum): - """The buffer types the codegen supports""" + """The type of information that a buffer contains.""" constant = auto() input_or_output = auto() @@ -50,7 +50,7 @@ class BufferType(Enum): class BufferInfo(NamedTuple): - """A data structure to hold metadata of the buffer""" + """A data structure to hold metadata of the buffer.""" # If the buffer holds constants, the values will contain that otherwise None values: np.ndarray @@ -90,9 +90,9 @@ def translate(tir_module, params): for extern_call in extern_calls: _npu_ops.append(translate_ethosu_tir_extern_call(extern_call)) _npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops) - target_accel_type = vela_api.get_target_accel_type() - cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_type) - payload = vapi.npu_create_driver_payload(cmds, target_accel_type) + target_accel_config = vela_api.get_accelerator_config() + cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config) + payload = vapi.npu_create_driver_payload(cmds, target_accel_config) hex_value = "" if constant_tensor is None else constant_tensor.tobytes().hex() return payload.hex(), hex_value, scratch_size @@ -125,9 +125,10 @@ def populate_extern_calls(stmt): return extern_calls -def extract_buffer_info(mod, param_dict): - """ - This function is to read the tvm.IRModule that +def extract_buffer_info( + mod: tvm.IRModule, param_dict: Dict[int, np.ndarray] +) -> Dict[str, BufferInfo]: + """This function is to read the tvm.IRModule that contains Relay to TIR compiled IRModule. Thereafter, this will extract the buffer information as the shape and constant data (if any). @@ -136,12 +137,14 @@ def extract_buffer_info(mod, param_dict): ---------- mod : tvm.IRModule The NPU TIR IRModule. - param_dict : dict + param_dict : Dict[int, np.ndarray] A dictionary containing param idx --> const numpy.NDArray + Returns ------- - dict - a dictionary of buffer names --> BufferInfo + dict : Dict[str, BufferInfo] + A dictionary of buffer names --> BufferInfo + """ buffer_info = dict() # There should only be a single function @@ -328,14 +331,15 @@ def translate_ethosu_copy(tir_extern_call): return _create_npu_dma_op(serial_object) -def _convert_clip_bounds(npu_op): - """ - This function will convert the min and max value +def _convert_clip_bounds(npu_op: vapi.NpuBlockOperation): + """This function will convert the min and max value of clip activations to non quantized floats as expected by the API. + Parameters ---------- - npu_op : ethosu.vela.api.NpuBlockOperation + npu_op : vapi.NpuBlockOperation + """ clip_min_quant = npu_op.activation.min clip_max_quant = npu_op.activation.max @@ -349,13 +353,14 @@ def _convert_clip_bounds(npu_op): npu_op.activation.max = clip_max_actual -def translate_ethosu_conv2d(tir_extern_call): - """This function will translate a tir extern_call - as produced by Relay to TIR compilation. +def translate_ethosu_conv2d(tir_call_extern: tvm.tir.Call) -> Tuple[vapi.NpuConv2DOperation, int]: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. + Parameters ---------- - tir_extern_call : tvm.tir.Call - This should be an tir external call that has a agreed upon ordering + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has a agreed upon ordering for TIR Compiler. See Serial2DConvolution in tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. @@ -365,15 +370,18 @@ def translate_ethosu_conv2d(tir_extern_call): The vela object containing the params of ethosu_conv2d weights_zero_point : int The zero point of the weights + """ - # We skip the first element as it is the extern_call function name - serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_extern_call.args[1:]) + # We skip the first element as it is the call_extern function name + serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_call_extern.args[1:]) return _create_npu_op_conv2d(serial_object) -def _create_npu_op_conv2d(serial_2d_convolution): +def _create_npu_op_conv2d( + serial_2d_convolution: spec.Serial2DConvolution, +) -> Tuple[vapi.NpuConv2DOperation, int]: """This is a helper function to capture a list - of arguments to create Vela NpuConv2DOperation object + of arguments to create Vela NpuConv2DOperation object. """ npu_conv2d_op = vapi.NpuConv2DOperation() npu_conv2d_op.ifm = _create_npu_feature_map(serial_2d_convolution.ifm) @@ -392,8 +400,8 @@ def _create_npu_op_conv2d(serial_2d_convolution): _convert_clip_bounds(npu_conv2d_op) npu_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale) - target_accel_type = vela_api.get_target_accel_type() # type: ignore - block_config = vela_api.get_optimal_block_config(npu_conv2d_op, target_accel_type) + accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_conv2d_op, accel_config) npu_conv2d_op.block_config = block_config weights_shape_ohwi = [ npu_conv2d_op.ofm.shape.depth, @@ -450,16 +458,16 @@ def _create_npu_op_depthwise_conv2d(serial_2d_depthwise): _convert_clip_bounds(npu_depthwise_conv2d_op) npu_depthwise_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale) - target_accel_type = vela_api.get_target_accel_type() - block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_type) + target_accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_config) npu_depthwise_conv2d_op.block_config = block_config return npu_depthwise_conv2d_op, weights_zero_point -def _create_npu_feature_map(serial_feature_map): +def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.NpuFeatureMap: """This is a helper function to capture a list - of arguments to create Vela NpuFeatureMap object + of arguments to create Vela NpuFeatureMap object. """ layout_map = {"NHWC": vapi.NpuLayout.NHWC, "NHCWB16": vapi.NpuLayout.NHCWB16} datatype_map = { @@ -476,14 +484,14 @@ def _create_npu_feature_map(serial_feature_map): nfm = vapi.NpuFeatureMap() nfm.data_type = datatype_map[data_type] nfm.shape = vapi.NpuShape3D( - int(serial_feature_map.height.value), - int(serial_feature_map.width.value), - int(serial_feature_map.channels.value), + int(serial_feature_map.height), + int(serial_feature_map.width), + int(serial_feature_map.channels), ) nfm.tiles = vapi.NpuTileBox( - int(serial_feature_map.tile_height_0.value), - int(serial_feature_map.tile_height_1.value), - int(serial_feature_map.tile_width_0.value), + int(serial_feature_map.tile_height_0), + int(serial_feature_map.tile_height_1), + int(serial_feature_map.tile_width_0), [ serial_feature_map.tile_address_0, serial_feature_map.tile_address_1, @@ -496,81 +504,75 @@ def _create_npu_feature_map(serial_feature_map): ) nfm.layout = layout_map[layout] nfm.strides = vapi.NpuShape3D( - int(serial_feature_map.stride_h.value), - int(serial_feature_map.stride_w.value), - int(serial_feature_map.stride_c.value), + int(serial_feature_map.stride_h), + int(serial_feature_map.stride_w), + int(serial_feature_map.stride_c), ) return nfm -def _create_npu_kernel(serial_kernel): +def _create_npu_kernel(serial_kernel: spec.SerialKernel) -> vapi.NpuKernel: """This is a helper function to capture a list - of arguments to create Vela NpuKernel object + of arguments to create Vela NpuKernel object. """ nknl = vapi.NpuKernel( - w=int(serial_kernel.width.value), - h=int(serial_kernel.height.value), - stride_x=int(serial_kernel.stride_w.value), - stride_y=int(serial_kernel.stride_h.value), - dilation_x=int(serial_kernel.dilation_w.value), - dilation_y=int(serial_kernel.dilation_h.value), + w=int(serial_kernel.width), + h=int(serial_kernel.height), + stride_x=int(serial_kernel.stride_w), + stride_y=int(serial_kernel.stride_h), + dilation_x=int(serial_kernel.dilation_w), + dilation_y=int(serial_kernel.dilation_h), ) return nknl -def _create_npu_address_range(serial_address_range): +def _create_npu_address_range( + serial_address_range: spec.SerialAddressRange, +) -> vapi.NpuAddressRange: """This is a helper function to capture a list - of arguments to create Vela NpuAddressRange object + of arguments to create Vela NpuAddressRange object. """ addr_range = vapi.NpuAddressRange( # region will be updated later region=0, address=serial_address_range.address, - length=int(serial_address_range.length.value), + length=int(serial_address_range.length), ) return addr_range def _create_npu_quantization( - scale, - zero_point, -): + scale: Union[tvm.tir.FloatImm, float], + zero_point: Union[tvm.tir.IntImm, int], +) -> vapi.NpuQuantization: """This is a helper function to capture a list - of arguments to create Vela NpuQuantization object + of arguments to create Vela NpuQuantization object. """ - # Scale could be an ndarray if per-channel quantization is available - if not isinstance(scale, tvm.tir.expr.Load): - if isinstance(scale.value, float): - scale = np.single(scale.value) - else: - assert isinstance(scale.value.value, float) - scale = np.single(scale.value.value) - q_params = vapi.NpuQuantization(scale_f32=scale, zero_point=zero_point.value) - return q_params + return vapi.NpuQuantization(scale_f32=float(scale), zero_point=int(zero_point)) def _create_npu_weights_zero_point( - zero_point, -): - """This is a helper function to capture the weights zero point""" - return zero_point.value + zero_point: Union[int, tvm.tir.IntImm], +) -> int: + """This is a helper function to capture the weights zero point.""" + return int(zero_point) -def _create_npu_padding(serial_padding): +def _create_npu_padding(serial_padding: spec.SerialPadding) -> vapi.NpuPadding: """This is a helper function to capture a list - of arguments to create Vela NpuPadding object""" + of arguments to create Vela NpuPadding object.""" padding = vapi.NpuPadding( - top=int(serial_padding.top.value), - left=int(serial_padding.left.value), - bottom=int(serial_padding.bottom.value), - right=int(serial_padding.right.value), + top=int(serial_padding.top), + left=int(serial_padding.left), + bottom=int(serial_padding.bottom), + right=int(serial_padding.right), ) return padding -def _create_npu_activation(serial_activation): +def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.NpuActivation: """This is a helper function to capture a list - of arguments to create Vela NpuActivation object""" + of arguments to create Vela NpuActivation object.""" if serial_activation.op == "NONE": return None if ( @@ -587,16 +589,16 @@ def _create_npu_activation(serial_activation): op = str(serial_activation.op.value) assert op in op_map.keys() act_op = vapi.NpuActivation(op_map[op]) - act_op.min = int(serial_activation.clip_min.value) - act_op.max = int(serial_activation.clip_max.value) + act_op.min = int(serial_activation.clip_min) + act_op.max = int(serial_activation.clip_max) return act_op def _create_npu_resampling_mode( - mode, -): + mode: str, +) -> vapi.NpuResamplingMode: """This is a helper function to capture a list - of arguments to create Vela NpuResamplingMode object""" + of arguments to create Vela NpuResamplingMode object.""" mode_map = { "NONE": vapi.NpuResamplingMode.NONE, "NEAREST": vapi.NpuResamplingMode.NEAREST, diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 6523352a0eea..69095e43416e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -27,6 +27,7 @@ import numpy as np # type: ignore from ethosu.vela import api as vapi # type: ignore +import tvm from tvm.relay.backend.contrib.ethosu import util # type: ignore from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs @@ -45,7 +46,7 @@ def get_optimal_block_config( - npu_op: vapi.NpuOperation, accel_type: vapi.NpuAccelerator + npu_op: vapi.NpuOperation, accel_config: vapi.NpuAccelerator ) -> vapi.NpuShape3D: """ "The NPU's unit of work is known as a block. It will fetch block(s) from Input @@ -58,15 +59,15 @@ def get_optimal_block_config( ---------- npu_op : ethosu.vela.api.NpuOperation The NPU operation and its params - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config Returns ------- ethosu.vela.api.NpuShape3D : The optimal block config for the operator """ - all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_type) + all_valid_block_configs = vapi.npu_find_block_configs(npu_op, accel_config) return _get_optimal_block_config(all_valid_block_configs) @@ -112,7 +113,9 @@ def _get_optimal_block_config(all_valid_block_configs: List[vapi.NpuShape3D]) -> return max_area_depth_block_configs[0] -def encode_weights(tir_extern_call, values, accel_type): +def encode_weights( + tir_extern_call: tvm.tir.Call, values: np.ndarray, accel_config: vapi.NpuAccelerator +): """This is an API function to compress weights by passing a tir_extern_call to NPU Convolution operation and values. @@ -122,8 +125,8 @@ def encode_weights(tir_extern_call, values, accel_type): tir_extern_call to NPU Convolution operation values : numpy.ndarray The constant flattened weight data in OHWI layout - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config Returns ------- @@ -137,7 +140,7 @@ def encode_weights(tir_extern_call, values, accel_type): op = str(tir_extern_call.args[0].value) assert op in supported_ops.keys() npu_op, weights_zero_point = supported_ops[op](tir_extern_call) - block_config = get_optimal_block_config(npu_op, accel_type) + block_config = get_optimal_block_config(npu_op, accel_config) # The weight layout is assumed to be flat OHWI, always. assert len(values.shape) == 1 is_depthwise = op == "ethosu_depthwise_conv2d" @@ -157,7 +160,7 @@ def encode_weights(tir_extern_call, values, accel_type): ifm_bitdepth=npu_op.ifm.data_type.size_in_bits(), block_depth=block_config.depth, dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y), - accel_type=accel_type, + accel_config=accel_config, is_depthwise=is_depthwise, ) @@ -169,7 +172,7 @@ def compress_weights( ifm_bitdepth: int, block_depth: int, dilation: Tuple[int, int], - accel_type: vapi.NpuAccelerator, + accel_config: vapi.NpuAccelerator, is_depthwise: Optional[bool] = False, ) -> bytearray: """The NPU requires the weights to be compressed @@ -191,8 +194,8 @@ def compress_weights( The depth of the optimal block config for the operator dilation : tuple A tuple of 2 elements indicating dilation in h and w - accel_type : ethosu.vela.api.NpuAccelerator - The NPU accelerator variant + accel_config : ethosu.vela.api.NpuAccelerator + The NPU accelerator config is_depthwise : bool, Optional This indicates whether the weights are compressed for depthwise convolution @@ -215,7 +218,7 @@ def compress_weights( ] block_traversal = calculate_block_traversal_mode(is_depthwise, shape_ohwi, ifm_bitdepth) compressed_weights = vapi.npu_encode_weights( - accelerator=accel_type, + accelerator=accel_config, weights_volume=weights_ohwi, dilation_xy=dilation, ifm_bitdepth=ifm_bitdepth, @@ -361,15 +364,24 @@ def _calculate_hw_bias_scales( return hw_bias_scales -def get_target_accel_type(): - """This is a helper function to convert cli accelerator type str argument - to NpuAccelerator""" +def get_accelerator_config() -> vapi.NpuAccelerator: + """Get the configuration of the NPU accelerator. + + The configuration string provided as a compiler option is converted into + an NpuAccelerator object. Valid configuration strings: + - 'ethos-u55-256' + - 'ethos-u55-128' + - 'ethos-u55-64' + - 'ethos-u55-32' + + """ npu_accel_str_map = { "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, "ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128, "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, } - accel_type_str = util.get_accelerator_config() - assert accel_type_str in npu_accel_str_map.keys(), f"{accel_type_str} is not supported" - return npu_accel_str_map[accel_type_str] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str in npu_accel_str_map.keys(), f"{accel_config_str} is not supported" + return npu_accel_str_map[accel_config_str] diff --git a/src/relay/backend/contrib/ethosu/to_te_graph.cc b/src/relay/backend/contrib/ethosu/to_te_graph.cc deleted file mode 100644 index 9646c39da089..000000000000 --- a/src/relay/backend/contrib/ethosu/to_te_graph.cc +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/contrib/ethosu/to_te_graph.cc - * \brief Lower a Relay function to a TE graph. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../compile_engine.h" -#include "../../utils.h" - -namespace tvm { -namespace relay { -namespace contrib { -namespace ethosu { - -/*! \brief Node container to represent a Tensor Expression graph. */ -class TEGraphNode : public Object { - public: - /* \brief The inputs to the graph */ - tvm::Array inputs; - /* \brief The outputs to the graph */ - tvm::Array outputs; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - } - - static constexpr const char* _type_key = "relay.TEGraph"; - TVM_DECLARE_FINAL_OBJECT_INFO(TEGraphNode, Object); -}; - -class TEGraph : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TEGraph, ObjectRef, TEGraphNode); -}; - -TVM_REGISTER_NODE_TYPE(TEGraphNode); - -Array GetShape(const Array& shape) { - // for now, we always use int32 shape when possible - // even if the result of shape inference becomes int64. - Array res; - for (IndexExpr val : shape) { - const int64_t* pval = tir::as_const_int(val); - if (pval != nullptr) { -#ifndef TVM_INDEX_DEFAULT_I64 - ICHECK_LE(pval[0], std::numeric_limits::max()); - ICHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); -#else - res.push_back(val); -#endif // TVM_INDEX_DEFAULT_I64 - } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); - } else { - res.push_back(val); - } - } - return res; -} - -class RelayToTE : public backend::MemoizedExprTranslator> { - public: - RelayToTE() = default; - - TEGraph Lower(const Function& prim_func) { - auto graph_node = make_object(); - for (Var param : prim_func->params) { - Array inputs; - if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - graph_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - graph_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } - } - memo_[param] = inputs; - } - graph_node->outputs = this->VisitExpr(prim_func->body); - return TEGraph(graph_node); - } - - Array VisitExpr_(const VarNode* op) final { - LOG(FATAL) << "Free variable " << op->name_hint(); - return {}; - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(op->is_scalar()); - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "compile_engine_const", topi::kBroadcast); - return {value}; - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); - ICHECK(flower_call) << "relay.backend.lower_call is not registered."; - - Array inputs; - int count_tuple = 0; - for (Expr arg : call_node->args) { - if (arg->checked_type().as()) { - ++count_tuple; - } - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - - Array outputs; - LoweredOutput lowered_out = - (*flower_call)(GetRef(call_node), inputs, tvm::Target("llvm")); - outputs = lowered_out->outputs; - - if (outputs.size() != 1) { - const auto* tuple_type = call_node->checked_type().as(); - ICHECK(tuple_type) << "Expect output to be a tuple type"; - ICHECK_EQ(tuple_type->fields.size(), outputs.size()); - } - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - const auto* tuple_type = op->tuple->type_as(); - Array tuple = VisitExpr(op->tuple); - ICHECK_EQ(tuple_type->fields.size(), tuple.size()); - ICHECK_GE(op->index, 0); - ICHECK_LT(static_cast(op->index), tuple.size()); - return {tuple[op->index]}; - } -}; - -TVM_REGISTER_GLOBAL("relay.backend.contrib.ethosu.LowerToTE") - .set_body_typed([](Function prim_func) { return RelayToTE().Lower(prim_func); }); - -} // namespace ethosu -} // namespace contrib -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d0e83765928a..ec87cfc98931 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -111,8 +111,10 @@ Array GetShape(const Array& shape) { // Construct a schedule for a given Relay primitive function and target. class ScheduleBuilder : public backend::MemoizedExprTranslator> { public: - explicit ScheduleBuilder(Target target) - : target_(target), device_copy_op_(Op::Get("device_copy")) { + explicit ScheduleBuilder(Target target, bool create_schedule = true) + : target_(target), + device_copy_op_(Op::Get("device_copy")), + create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); } @@ -149,7 +151,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator auto prim_fn_var = GlobalVar(prim_fn_name); prim_fn_var->checked_type_ = prim_func->checked_type(); - ICHECK(anchor_op_.defined()); // Fusion over tupled results may leave identity relationships // between inputs and outputs, and those should not be scheduled. // Hence schedule only non PlaceholderOp outputs. @@ -162,7 +163,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator te::Schedule schedule; // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr) { + if (anchor_attrs_.as() == nullptr && create_schedule_) { if (use_auto_scheduler_) { const auto* fauto_schedule = runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); @@ -259,17 +260,19 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator impl = lowered_out->implementation; } - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; + if (create_schedule_) { + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + anchor_implementation_ = impl; + } } if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); @@ -334,6 +337,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; + bool create_schedule_; }; /*! @@ -667,6 +671,12 @@ std::string GetUniqueName(std::string name, std::unordered_map return name; } +TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { + return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) { + return name; + }); +}); + } // namespace tec } // namespace relay } // namespace tvm diff --git a/tests/python/contrib/test_ethosu/test_attr_passing.py b/tests/python/contrib/test_ethosu/test_attr_passing.py index a2fbe1888d2a..6b99a5c1e540 100644 --- a/tests/python/contrib/test_ethosu/test_attr_passing.py +++ b/tests/python/contrib/test_ethosu/test_attr_passing.py @@ -28,7 +28,9 @@ def test_compiler_attr(): } with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethosu.options": config}): with tvm.target.Target("c -device=micro_dev"): - assert util.get_accelerator_config() == config["accelerator_config"] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str == config["accelerator_config"] def test_compiler_attr_default(): @@ -37,7 +39,9 @@ def test_compiler_attr_default(): } with tvm.transform.PassContext(opt_level=3): with tvm.target.Target("c -device=micro_dev"): - assert util.get_accelerator_config() == default_config["accelerator_config"] + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + accel_config_str = compiler_attrs.accelerator_config + assert accel_config_str == default_config["accelerator_config"] if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 60ed352edcfd..5b60102162be 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -64,10 +64,10 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, def test_weight_stream_only(): - def _planner(te_graph, const_dict, sch): - weights = te_graph.inputs[1] - bias = te_graph.inputs[2] - out = te_graph.outputs[0] + def _planner(cached_func, const_dict, sch): + weights = cached_func.inputs[1] + bias = cached_func.inputs[2] + out = cached_func.outputs[0] conv_compute = Convolution2DCompute.from_output(out) co = conv_compute.split(sch, 3, 2) cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d]) @@ -208,10 +208,10 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle def test_mixed_read(): - def _planner(te_graph, const_dict, sch): - weight = te_graph.inputs[4] - scale_bias = te_graph.inputs[5] - out = te_graph.outputs[0] + def _planner(cached_func, const_dict, sch): + weight = cached_func.inputs[4] + scale_bias = cached_func.inputs[5] + out = cached_func.outputs[0] conv_compute = Convolution2DCompute.from_output(out) co = conv_compute.split(sch, 3, 2) cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 8077271ed496..b04059011e8e 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -81,10 +81,10 @@ def test_inline_no_ops(): func = relay.Function(relay.analysis.free_vars(relu2), relu2) func = run_opt_pass(func, relay.transform.InferType()) - te_graph = lower_to_te(func) - sch = te.create_schedule([te_graph.outputs[0].op]) - inline_no_ops(te_graph, sch) - reshape_tensor = te_graph.outputs[0].op.input_tensors[0] + cached_func = lower_to_te(func) + sch = te.create_schedule([cached_func.outputs[0].op]) + inline_no_ops(cached_func, sch) + reshape_tensor = cached_func.outputs[0].op.input_tensors[0] slice_tensor = reshape_tensor.op.input_tensors[0].op.input_tensors[0] assert sch[reshape_tensor].attach_type == AttachType.kInline assert sch[slice_tensor].attach_type == AttachType.kInline @@ -114,11 +114,11 @@ def test_copy_constants(): func = run_opt_pass(func, relay.transform.InferType()) func, const_dict = extract_constants(func) - te_graph = lower_to_te(func) + cached_func = lower_to_te(func) - sch = te.create_schedule([te_graph.outputs[0].op]) + sch = te.create_schedule([cached_func.outputs[0].op]) planner = copy_constants() - planner(te_graph, const_dict, sch) + planner(cached_func, const_dict, sch) assert len(sch.stages) == 21 assert ".global" in sch.stages[5].op.name assert ".global" in sch.stages[7].op.name diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index 02c305387d45..cf845db2b43b 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -354,18 +354,17 @@ def create_mock(test_vec): max = np.iinfo(ifm_dtype).max min = np.iinfo(ifm_dtype).min values = np.random.randint(min, max, test_vec["shape"], ifm_dtype) - compressed_weights = vela_api.compress_weights( + vela_api.compress_weights( weights=values, weights_zp=test_vec["zero_point"], weights_layout=test_vec["layout"], ifm_bitdepth=ifm_bitdepth, block_depth=test_vec["block_depth"], dilation=test_vec["dilation"], - accel_type=test_vec["accel"], + accel_config=test_vec["accel"], is_depthwise=test_vec["is_depthwise"], ) return mock_npu_encode_weights - return None for tv in test_vecs: mock_obj = create_mock(tv) From 8bd845deed9d3c7f0def2c98489d239f7663dea5 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 12 Oct 2021 19:41:45 -0500 Subject: [PATCH 15/84] [Simplifier] Add printing of SplitExprNode and SumExprNode (#9262) --- src/arith/canonical_simplify.cc | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 94db659e25c9..29153037b9fa 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -535,6 +535,40 @@ void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) { this->AddToSelf(other->base * scale); } +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + auto factor_str = [](int64_t f) { + return f == SplitExprNode::kPosInf ? std::string("+inf") : std::to_string(f); + }; + p->stream << "split("; + p->Print(op->index); + p->stream << ", lower=" << factor_str(op->lower_factor) + << ", upper=" << factor_str(op->upper_factor) << ", scale=" << op->scale + << ", div_mode="; + switch (op->div_mode) { + // No "default", so that the compiler will emit a warning if more div modes are + // added that are not covered by the switch. + case kTruncDiv: + p->stream << "truncdiv"; + break; + case kFloorDiv: + p->stream << "floordiv"; + break; + } + p->stream << ')'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sum(base=" << op->base; + for (const SplitExpr& s : op->args) { + p->stream << ", "; + p->Print(s); + } + }); + // Sub-class RewriteSimplifier::Impl to take benefit of // rewriter for condition simplification etc. class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { From b5d863ca52875e8234af69c85e6ebdc1e0795b9b Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 12 Oct 2021 19:07:20 -0700 Subject: [PATCH 16/84] fix docs (#9266) --- tests/scripts/task_python_docs.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 2eb471cbc69f..765c84137730 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -83,11 +83,11 @@ cd .. rm -rf _docs mv docs/_build/html _docs rm -f _docs/.buildinfo -mkdir -p _docs/api -mv docs/doxygen/html _docs/api/doxygen -mv jvm/core/target/site/apidocs _docs/api/javadoc +mkdir -p _docs/reference/api +mv docs/doxygen/html _docs/reference/api/doxygen +mv jvm/core/target/site/apidocs _docs/reference/api/javadoc # mv rust/target/doc _docs/api/rust -mv web/dist/docs _docs/api/typedoc +mv web/dist/docs _docs/reference/api/typedoc echo "Start creating the docs tarball.." # make the tarball From 3229cb329254764499dd672bb28fd9685ecd6a2e Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 12 Oct 2021 21:08:54 -0500 Subject: [PATCH 17/84] [LLVM] Treat scalars as single-lane vectors in CreateVecConcat (#9264) LLVM differentiates between `<1 x ty>` and `ty`, while TVM does not. Make sure that a bunch of TVM scalars can be concatenated into a vector when generating LLVM IR. --- src/target/llvm/codegen_llvm.cc | 14 ++++++++++++++ .../python/unittest/test_target_codegen_llvm.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 12fbf2c3e42c..c94c5a685d1b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -626,6 +626,20 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { } llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { + // To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane + // LLVM vector types. + for (size_t i = 0, e = vecs.size(); i != e; ++i) { + llvm::Value* v = vecs[i]; + if (!v->getType()->isVectorTy()) { +#if TVM_LLVM_VERSION >= 110 + llvm::Type* vec_ty = llvm::FixedVectorType::get(v->getType(), 1); +#else + llvm::Type* vec_ty = llvm::VectorType::get(v->getType(), 1); +#endif + vecs[i] = builder_->CreateInsertElement(llvm::UndefValue::get(vec_ty), v, ConstInt32(0)); + } + } + // concat vector, tree shape reduction int total_lanes = 0; diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 8c8d601672ac..5a1b33ae10b1 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -885,5 +885,21 @@ def check_llvm(use_file): check_llvm(use_file=False) +@tvm.testing.requires_llvm +def test_llvm_scalar_concat(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.decl_buffer((1,), "int32x2") + s = tvm.tir.Shuffle([x, y], [0, 1]) + f = tvm.tir.PrimFunc([x, y, z], z.vstore(0, s)) + + mod = tvm.ir.IRModule.from_expr(f.with_attr("global_symbol", "codegen_scalar_concat")) + + # This will crash in LLVM codegen if CodeGenLLVM::CreateVecConcat doesn't convert + # scalars to single-lane LLVM vectors. + with tvm.transform.PassContext(config={"tir.disable_assert": True}): + m = tvm.build(mod, [x, y, z], target="llvm") + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 617c71273892216ca727466ee40bd263801a7b39 Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Tue, 12 Oct 2021 21:10:19 -0500 Subject: [PATCH 18/84] [TIR] Added PrettyPrint of ProducerStore/ProducerRealize nodes (#9259) --- src/printer/text_printer.h | 7 +++++ src/printer/tir_text_printer.cc | 47 +++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 3514f3228e27..a2178167b2e3 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -276,6 +276,8 @@ class TIRTextPrinter : public StmtFunctor, std::unordered_map memo_var_; /*! \brief Map from Buffer to Doc */ std::unordered_map memo_buf_; + /*! \brief Map from Buffer to Doc */ + std::unordered_map memo_producer_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; @@ -321,7 +323,9 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const AssertStmtNode* op) override; Doc VisitStmt_(const StoreNode* op) override; Doc VisitStmt_(const BufferStoreNode* op) override; + Doc VisitStmt_(const ProducerStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; + Doc VisitStmt_(const ProducerRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; @@ -342,7 +346,9 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintIterVar(const IterVarNode* op); Doc PrintRange(const RangeNode* op); Doc PrintBuffer(const BufferNode* op); + Doc PrintProducer(const DataProducerNode* op); Doc BufferNode2Doc(const BufferNode* op, Doc doc); + Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc); Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } Doc PrintBufferRegion(const BufferRegionNode* op); @@ -361,6 +367,7 @@ class TIRTextPrinter : public StmtFunctor, Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); + Doc AllocProducer(const DataProducer& buffer); /*! * \brief special method to render vectors of docs with a separator * \param vec vector of docs diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fa132f079793..302c4491cebe 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -65,6 +65,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) { return PrintRange(node.as()); } else if (node->IsInstance()) { return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintProducer(node.as()); } else if (node->IsInstance()) { return PrintString(node.as()); } else if (node->IsInstance()) { @@ -199,6 +201,19 @@ Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { } } +Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) { + const DataProducer& prod = GetRef(op); + + if (meta_->InMeta(prod)) { + return meta_->GetMetaNode(prod); + } else if (memo_producer_.count(prod)) { + return memo_producer_[prod]; + } else { + memo_producer_[prod] = AllocProducer(prod); + return DataProducerNode2Doc(op, memo_producer_[prod]); + } +} + Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " << Print(buf->strides); @@ -220,6 +235,11 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { return doc << ")"; } +Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) { + return doc << Doc::Text(": DataProducer(") << Print(prod->GetNameHint()) << ", " + << PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")"; +} + Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; doc << Print(op->buffer) << "["; @@ -439,6 +459,12 @@ Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) { + Doc doc; + doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc doc; doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", " @@ -446,6 +472,13 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) { + Doc doc; + doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", " + << Print(op->condition) << ", " << PrintBody(op->body) << ")"; + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto scope = GetPtrStorageScope(op->buffer_var); @@ -709,6 +742,20 @@ Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) { return val; } +Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) { + const auto& it = memo_producer_.find(producer); + if (it != memo_producer_.end()) { + return it->second; + } + std::string name = producer->GetNameHint(); + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "tensor_" + name; + } + Doc val = GetUniqueName(name); + memo_producer_[producer] = val; + return val; +} + Doc TIRTextPrinter::PrintSep(const std::vector& vec, const Doc& sep) { Doc seq; if (vec.size() != 0) { From 3ee8efa43c22b246076cb2b9019544ed5619db2d Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Tue, 12 Oct 2021 21:12:02 -0500 Subject: [PATCH 19/84] [TIR] Minor refactor to tir.transform.StorageFlatten (#9260) Expressed each step as a separate `transform::Pass`, so they can be used/inspected individually. --- src/tir/transforms/storage_flatten.cc | 114 +++++++++++++++++++++----- 1 file changed, 92 insertions(+), 22 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 6a3ce596c2fe..ccc660509ca1 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -60,6 +60,19 @@ using runtime::ThreadScope; */ class BufferShapeLegalize : public StmtExprMutator { public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferShapeLegalize", {}); + } + explicit BufferShapeLegalize(const Map& extern_buffer_map, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) { @@ -383,6 +396,19 @@ class BufferShapeLegalize : public StmtExprMutator { */ class BufferStrideLegalize : public StmtExprMutator { public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferStrideLegalize", {}); + } + explicit BufferStrideLegalize(const Map& extern_buffer_map, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) { @@ -565,6 +591,15 @@ class BufferStrideLegalize : public StmtExprMutator { */ class ThreadScopePropagate : public StmtExprMutator { public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + auto fptr = func.CopyOnWrite(); + fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.ThreadScopePropagate", {}); + } + explicit ThreadScopePropagate(const Map& extern_buffer_map) { // External buffers shouldn't be overwritten, even if they have a // BufferRealizeNode. @@ -718,6 +753,19 @@ class ThreadScopePropagate : public StmtExprMutator { */ class BufferBindUnwrapper : public StmtExprMutator { public: + static transform::Pass Pass() { + auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferBindUnwrapper", {}); + } + explicit BufferBindUnwrapper(const Map& extern_buffer_map, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) { @@ -1030,6 +1078,20 @@ class BufferBindUnwrapper : public StmtExprMutator { class StorageFlattener : public StmtExprMutator { public: + static transform::Pass Pass(int cache_line_size, bool create_bound_attributes) { + auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, + &bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); + } + explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { @@ -1355,6 +1417,19 @@ class StorageFlattener : public StmtExprMutator { */ class AssertSimplifier : public StmtMutator { public: + static transform::Pass Pass() { + auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) { + IRVisitorWithAnalyzer bound_analyzer; + + bound_analyzer(func->body); + + auto fptr = func.CopyOnWrite(); + fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body)); + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.AssertSimplifier", {}); + } + explicit AssertSimplifier(IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) {} @@ -1409,30 +1484,25 @@ class AssertSimplifier : public StmtMutator { // We do support a few relaxed case, such as binding a // region with shape [1, 1, n, m] to buffer with shape [n, m] PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { - // Only apply this pass to TIR from TE schedules + // Only apply this pass to TIR from TE schedules. Because this is a + // per-function attribute, we can't just check it once for the + // entire module and apply the Sequential transform. Optional from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false)); if (from_legacy_te_schedule.value()) { - auto fptr = func.CopyOnWrite(); - - IRVisitorWithAnalyzer bound_analyzer; - bound_analyzer(fptr->body); - - fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); - - auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer); - fptr->body = stride_legalize(std::move(fptr->body)); - fptr->buffer_map = stride_legalize.UpdatedExternBufferMap(); - - fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body)); - - fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); - - fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, - &bound_analyzer)(std::move(fptr->body)); - - fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body)); - - return func; + auto seq = transform::Sequential( + { + BufferShapeLegalize::Pass(), + BufferStrideLegalize::Pass(), + ThreadScopePropagate::Pass(), + BufferBindUnwrapper::Pass(), + StorageFlattener::Pass(cache_line_size, create_bound_attributes), + AssertSimplifier::Pass(), + }, + "tir.StorageFlatten_impl"); + GlobalVar dummy_func_name("dummy_func"); + IRModule mod(Map({{dummy_func_name, func}})); + mod = seq(mod); + return Downcast(mod->Lookup(dummy_func_name)); } else { return func; } From 8a3fcc40a4699bf6df13d5409b043819c087ee5f Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Wed, 13 Oct 2021 03:31:14 +0100 Subject: [PATCH 20/84] [TVMC] Compose target options from target registry (#9218) * [TVMC] Split common tvmc test file into more specific files The `test_tvmc_common.py` file was becoming a bit of a mixed bag of tests and as we now want to extend the `Target` processing logic it made sense to split each out into its own file to make it clearer what each does. `test_common.py` has also been renamed before we start using it for all the tests instead. * [TVMC] Compose target options from target registry [The RFC for this is still under discussion](https://github.com/apache/tvm-rfcs/pulls), but doing this before splitting the registries makes the most sense. This enables the `tvmc` driver to re-combobulate Target options from arguments: ``` tvmc --target=llvm \ --target-llvm-mcpu=cortex-m3 ``` --- include/tvm/target/target_kind.h | 6 ++ python/tvm/driver/tvmc/autotuner.py | 17 +++-- python/tvm/driver/tvmc/common.py | 45 +++++++++-- python/tvm/driver/tvmc/compiler.py | 15 ++-- python/tvm/driver/tvmc/target.py | 74 +++++++++++++++++++ python/tvm/target/target.py | 5 ++ src/target/target_kind.cc | 10 +++ tests/cpp/target_test.cc | 12 ++- tests/python/driver/tvmc/test_compiler.py | 4 +- tests/python/driver/tvmc/test_mlf.py | 4 +- .../python/driver/tvmc/test_target_options.py | 71 ++++++++++++++++++ 11 files changed, 237 insertions(+), 26 deletions(-) create mode 100644 python/tvm/driver/tvmc/target.py create mode 100644 tests/python/driver/tvmc/test_target_options.py diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 9d8695a43aff..e802a3088d2d 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -229,6 +229,12 @@ class TargetKindRegEntry { * \return The entry names. */ TVM_DLL static Array ListTargetKinds(); + /*! + * \brief Get all supported option names and types for a given Target kind. + * \return Map of option name to type + */ + TVM_DLL static Map ListTargetKindOptions(const TargetKind& kind); + /*! * \brief Register or get a new entry. * \param target_kind_name The name of the TargetKind. diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index dab855abfb11..92d13a99acd5 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -21,7 +21,7 @@ import logging import time from copy import deepcopy -from typing import Optional, Dict, List, Union +from typing import Any, Optional, Dict, List, Union from urllib.parse import urlparse @@ -38,6 +38,7 @@ from .common import TVMCException from .main import register_parser from .model import TVMCModel +from .target import generate_target_args, reconstruct_target_args # pylint: disable=invalid-name @@ -106,16 +107,14 @@ def add_tune_parser(subparsers): help="hostname (required) and port (optional, defaults to 9090) of the RPC tracker, " "e.g. '192.168.0.100:9999'", ) - parser.add_argument( - "--target", - help="compilation target as plain string, inline JSON or path to a JSON file", - required=True, - ) + + generate_target_args(parser) parser.add_argument( "--target-host", help="the host compilation target, defaults to 'llvm'", default="llvm", ) + parser.add_argument("--timeout", type=int, default=10, help="compilation timeout, in seconds") parser.add_argument( "--trials", @@ -286,6 +285,7 @@ def drive_tune(args): hardware_params=hardware_params, include_simple_tasks=args.include_simple_tasks, log_estimated_latency=args.log_estimated_latency, + additional_target_options=reconstruct_target_args(args), ) @@ -311,6 +311,7 @@ def tune_model( hardware_params: Optional[HardwareParams] = None, include_simple_tasks: bool = False, log_estimated_latency: bool = False, + additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, ): """Use tuning to automatically optimize the functions in a model. @@ -367,13 +368,15 @@ def tune_model( the autoscheduler. log_estimated_latency : bool, optional If using the autoscheduler, write the estimated latency at each step of tuning to file. + additional_target_options: Optional[Dict[str, Dict[str, Any]]] + Additional target options in a dictionary to combine with initial Target arguments Returns ------- tuning_records : str The path to the produced tuning log file. """ - target, extra_targets = common.target_from_cli(target) + target, extra_targets = common.target_from_cli(target, additional_target_options) target, target_host = Target.check_and_update_host_consist(target, target_host) # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source # model is fixed. For now, creating a clone avoids the issue. diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 9ef2f6f1fbfa..f4bc3ec027d7 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -80,7 +80,7 @@ def convert_graph_layout(mod, desired_layout): ) -def validate_targets(parse_targets): +def validate_targets(parse_targets, additional_target_options=None): """ Apply a series of validations in the targets provided via CLI. """ @@ -104,6 +104,15 @@ def validate_targets(parse_targets): f"Found: {verbose_tvm_targets}." ) + if additional_target_options is not None: + for target_name in additional_target_options: + if not any([target for target in parse_targets if target["name"] == target_name]): + first_option = list(additional_target_options[target_name].keys())[0] + raise TVMCException( + f"Passed --target-{target_name}-{first_option}" + f" but did not specify {target_name} target" + ) + def tokenize_target(target): """ @@ -261,7 +270,21 @@ def is_inline_json(target): return False -def target_from_cli(target): +def _combine_target_options(target, additional_target_options=None): + if additional_target_options is None: + return target + if target["name"] in additional_target_options: + target["opts"].update(additional_target_options[target["name"]]) + return target + + +def _recombobulate_target(target): + name = target["name"] + opts = " ".join([f"-{key}={value}" for key, value in target["opts"].items()]) + return f"{name} {opts}" + + +def target_from_cli(target, additional_target_options=None): """ Create a tvm.target.Target instance from a command line interface (CLI) string. @@ -272,6 +295,10 @@ def target_from_cli(target): compilation target as plain string, inline JSON or path to a JSON file + additional_target_options: Optional[Dict[str, Dict[str,str]]] + dictionary of additional target options to be + combined with parsed targets + Returns ------- tvm.target.Target @@ -298,18 +325,22 @@ def target_from_cli(target): except ValueError as ex: raise TVMCException(f"Error parsing target string '{target}'.\nThe error was: {ex}") - validate_targets(parsed_targets) - tvm_targets = [t for t in parsed_targets if t["is_tvm_target"]] + validate_targets(parsed_targets, additional_target_options) + tvm_targets = [ + _combine_target_options(t, additional_target_options) + for t in parsed_targets + if t["is_tvm_target"] + ] # Validated target strings have 1 or 2 tvm targets, otherwise # `validate_targets` above will fail. if len(tvm_targets) == 1: - target = tvm_targets[0]["raw"] + target = _recombobulate_target(tvm_targets[0]) target_host = None else: assert len(tvm_targets) == 2 - target = tvm_targets[0]["raw"] - target_host = tvm_targets[1]["raw"] + target = _recombobulate_target(tvm_targets[0]) + target_host = _recombobulate_target(tvm_targets[1]) extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]] diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 9eb85a4934cb..7623a141c27a 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -19,7 +19,7 @@ """ import logging import os.path -from typing import Optional, Dict, List, Union, Callable +from typing import Any, Optional, Dict, List, Union, Callable from pathlib import Path import tvm @@ -30,6 +30,7 @@ from . import common, composite_target, frontends from .model import TVMCModel, TVMCPackage from .main import register_parser +from .target import generate_target_args, reconstruct_target_args # pylint: disable=invalid-name @@ -91,11 +92,7 @@ def add_compile_parser(subparsers): "times, each one to set one configuration value, " "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.", ) - parser.add_argument( - "--target", - help="compilation targets as comma separated string, inline JSON or path to a JSON file.", - required=True, - ) + generate_target_args(parser) parser.add_argument( "--tuning-records", metavar="PATH", @@ -154,6 +151,7 @@ def drive_compile(args): desired_layout=args.desired_layout, disabled_pass=args.disabled_pass, pass_context_configs=args.pass_config, + additional_target_options=reconstruct_target_args(args), ) return 0 @@ -172,6 +170,7 @@ def compile_model( desired_layout: Optional[str] = None, disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, + additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, ): """Compile a model from a supported framework into a TVM module. @@ -215,6 +214,8 @@ def compile_model( pass_context_configs: list[str], optional List of strings containing a set of configurations to be passed to the PassContext. + additional_target_options: Optional[Dict[str, Dict[str, Any]]] + Additional target options in a dictionary to combine with initial Target arguments Returns @@ -230,7 +231,7 @@ def compile_model( if desired_layout: mod = common.convert_graph_layout(mod, desired_layout) - tvm_target, extra_targets = common.target_from_cli(target) + tvm_target, extra_targets = common.target_from_cli(target, additional_target_options) tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host) for codegen_from_cli in extra_targets: diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py new file mode 100644 index 000000000000..7a078b8be087 --- /dev/null +++ b/python/tvm/driver/tvmc/target.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This file contains functions for processing target inputs for the TVMC CLI +""" + +from tvm.target import Target + +# We can't tell the type inside an Array but all current options are strings so +# it can default to that. Bool is used alongside Integer but aren't distinguished +# between as both are represented by IntImm +INTERNAL_TO_NATIVE_TYPE = {"runtime.String": str, "IntImm": int, "Array": str} +INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} + + +def _generate_target_kind_args(parser, kind): + target_group = parser.add_argument_group(f"target {kind.name}") + for target_option, target_type in kind.options.items(): + if target_type in INTERNAL_TO_NATIVE_TYPE: + target_group.add_argument( + f"--target-{kind.name}-{target_option}", + type=INTERNAL_TO_NATIVE_TYPE[target_type], + help=f"target {kind.name} {target_option}{INTERNAL_TO_HELP[target_type]}", + ) + + +def generate_target_args(parser): + """Walks through the TargetKind registry and generates arguments for each Target's options""" + parser.add_argument( + "--target", + help="compilation target as plain string, inline JSON or path to a JSON file", + required=True, + ) + target_kinds = Target.list_kinds() + for target_kind in target_kinds: + target = Target(target_kind) + _generate_target_kind_args(parser, target.kind) + + +def _reconstruct_target_kind_args(args, kind): + kind_options = {} + for target_option, target_type in kind.options.items(): + if target_type in INTERNAL_TO_NATIVE_TYPE: + var_name = f"target_{kind.name}_{target_option.replace('-', '_')}" + option_value = getattr(args, var_name) + if option_value is not None: + kind_options[target_option] = getattr(args, var_name) + return kind_options + + +def reconstruct_target_args(args): + """Reconstructs the target options from the arguments""" + target_kinds = Target.list_kinds() + reconstructed = {} + for target_kind in target_kinds: + target = Target(target_kind) + kind_options = _reconstruct_target_kind_args(args, target.kind) + if kind_options: + reconstructed[target.kind.name] = kind_options + return reconstructed diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 4e5826f5b2a2..9af09296e9cc 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -31,6 +31,11 @@ class TargetKind(Object): """Kind of a compilation target""" + @property + def options(self): + """Returns the dict of available option names and types""" + return dict(_ffi_api.ListTargetKindOptions(self)) + @tvm._ffi.register_object class Target(Object): diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d719386d204b..7cd329f83738 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -49,6 +49,14 @@ Array TargetKindRegEntry::ListTargetKinds() { return TargetKindRegistry::Global()->ListAllNames(); } +Map TargetKindRegEntry::ListTargetKindOptions(const TargetKind& target_kind) { + Map options; + for (const auto& kv : target_kind->key2vtype_) { + options.Set(kv.first, kv.second.type_key); + } + return options; +} + TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) { return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name); } @@ -359,5 +367,7 @@ TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("de /********** Registry **********/ TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds); +TVM_REGISTER_GLOBAL("target.ListTargetKindOptions") + .set_body_typed(TargetKindRegEntry::ListTargetKindOptions); } // namespace tvm diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 2e8ba11c0262..6106eb2225e1 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -152,8 +152,18 @@ TEST(TargetCreation, DeduplicateKeys) { ICHECK_EQ(target->GetAttr("link-params"), false); } -TEST(TargetKindRegistryListTargetKinds, Basic) { +TEST(TargetKindRegistry, ListTargetKinds) { Array names = TargetKindRegEntry::ListTargetKinds(); ICHECK_EQ(names.empty(), false); ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1); } + +TEST(TargetKindRegistry, ListTargetOptions) { + TargetKind llvm = TargetKind::Get("llvm").value(); + Map attrs = TargetKindRegEntry::ListTargetKindOptions(llvm); + ICHECK_EQ(attrs.empty(), false); + + ICHECK_EQ(attrs["mattr"], "Array"); + ICHECK_EQ(attrs["mcpu"], "runtime.String"); + ICHECK_EQ(attrs["system-lib"], "IntImm"); +} diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 6e57796d1cbf..9d44d8f22f41 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -397,7 +397,7 @@ def test_compile_tflite_module_with_external_codegen_cmsisnn( tvmc_package = tvmc.compiler.compile_model( tvmc_model, - target=f"cmsis-nn, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 --executor=aot", + target=f"cmsis-nn, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 -executor=aot", output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], @@ -455,7 +455,7 @@ def test_compile_tflite_module_with_external_codegen_ethosu( tvmc_package = tvmc.compiler.compile_model( tvmc_model, - target=f"ethos-u -accelerator_config={accel_type}, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 --executor=aot", + target=f"ethos-u -accelerator_config={accel_type}, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 -executor=aot", output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 0426f5678153..11306bd58848 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize( - "target,pass_configs", [["llvm", []], ["c --executor=aot", ["tir.disable_vectorize=1"]]] + "target,pass_configs", [["llvm", []], ["c -executor=aot", ["tir.disable_vectorize=1"]]] ) def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory, target, pass_configs): pytest.importorskip("tflite") @@ -114,7 +114,7 @@ def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile tflite_compiled_model_mlf = tflite_compile_model( tflite_mobilenet_v1_1_quant, - target="c --executor=aot", + target="c -executor=aot", output_format="mlf", pass_context_configs=["tir.disable_vectorize=1"], ) diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py new file mode 100644 index 000000000000..f6942299b751 --- /dev/null +++ b/tests/python/driver/tvmc/test_target_options.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse + +import pytest + +from tvm.driver import tvmc +from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc.target import generate_target_args, reconstruct_target_args + + +def test_target_to_argparse(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + ["--target=llvm", "--target-llvm-mattr=+fp,+mve", "--target-llvm-mcpu=cortex-m3"] + ) + assert parsed.target == "llvm" + assert parsed.target_llvm_mcpu == "cortex-m3" + assert parsed.target_llvm_mattr == "+fp,+mve" + + +def test_mapping_target_args(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args(["--target=llvm", "--target-llvm-mcpu=cortex-m3"]) + assert reconstruct_target_args(parsed) == {"llvm": {"mcpu": "cortex-m3"}} + + +def test_target_recombobulation_single(): + tvm_target, _ = tvmc.common.target_from_cli("llvm", {"llvm": {"mcpu": "cortex-m3"}}) + + assert str(tvm_target) == "llvm -keys=cpu -link-params=0 -mcpu=cortex-m3" + + +def test_target_recombobulation_many(): + tvm_target, _ = tvmc.common.target_from_cli( + "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu", + {"llvm": {"mcpu": "cortex-m3"}, "opencl": {"max_num_threads": 404}}, + ) + + assert "-max_num_threads=404" in str(tvm_target) + assert "-device=mali" in str(tvm_target) + assert "-mtriple=aarch64-linux-gnu" in str(tvm_target.host) + assert "-mcpu=cortex-m3" in str(tvm_target.host) + + +def test_error_if_target_missing(): + with pytest.raises( + TVMCException, + match="Passed --target-opencl-max_num_threads but did not specify opencl target", + ): + tvmc.common.target_from_cli( + "llvm", + {"opencl": {"max_num_threads": 404}}, + ) From 80beda79066c8a1fba12bb8847d292e2f27e5c5c Mon Sep 17 00:00:00 2001 From: Adam Straw Date: Wed, 13 Oct 2021 06:25:48 -0700 Subject: [PATCH 21/84] Hexagon conv2d full output slice (#9198) * split h axis by constant factor 2; no cache write * enable cache_write, but not yet able to compute_at * cache_write with compute_at * cleanup, make loop split semantics more clear * parameterize height loop split * nhwhwc wiggling (needs cleanup) * added input channel splits for crouton depth * cleanup variable names and magic numbers * comments * add README * added 3x3 conv2d (no padding) case * add ASF header and RFC link * cleanup README --- tests/python/contrib/test_hexagon/README.md | 431 ++++++++++++++++++ .../test_hexagon/test_conv2d_blocked.py | 117 ++++- 2 files changed, 529 insertions(+), 19 deletions(-) create mode 100644 tests/python/contrib/test_hexagon/README.md diff --git a/tests/python/contrib/test_hexagon/README.md b/tests/python/contrib/test_hexagon/README.md new file mode 100644 index 000000000000..1d6a298d48d6 --- /dev/null +++ b/tests/python/contrib/test_hexagon/README.md @@ -0,0 +1,431 @@ + + + + + + + + + + + + + + + + + +Documents manual TE schedule to illustrate Hexagon operator slicing. + +# High Level Notes + +* Using float32 (for now) so that tests will pass on CPU +* Using global storage scope (for now) which means "cache" reads and writes from global, to global +* TIR is pending changes from the work-in-progress layout RFC + (https://github.com/apache/tvm-rfcs/pull/39) +* TIR has been hand-edited for context and clarity + * Added C-style comments + * Changed variable names + * Added spacing and line breaks +* Naming conventions + * Using input (instead of activation) + * Using kernel (instead of weight, filter) + * Using `k` to denote channel-out and `c` or `rc` (reduction channel) to denote channel-in + * Using `rh` and `rw` (reduction height / width) to denote kernel height and width + +# Calling Convention + +TODO: Map this packed string to parameters +conv2d_packed_filter-1-1-0-float32-1-1-64-64-64-llvm + +# Baseline conv2d + +This is a baseline 1x1 conv2d schedule for Hexagon. + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-1-1-64-64-64-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Kernel | 1x1 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 64 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | + +## Assumptions + +* Microkernels will compute "full depth" in channel-out (k) dimension. + * The compute schedule (see TIR below) + * Places the outer channel-out loop over `ko` inside the outer width loop over `wo` + * Encodes the assumption that Hexagon microkernels will compute "full depth" in the channel-out (k) dimension + +## To Do + +* Adjust compute schedule and add kernel cache read once Hexagon microkernel semantics are understood + +## Annotated TIR + +``` +primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), // NHWC8h8w32c + kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending layout RFC) + buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { + + allocate(input.cache: Pointer(global float32), float32, [32768]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [32768]), storage_scope = global; + + for (ho.outer: int32, 0, 8) { + // cache read + // NHWC -> NHWC8h8w32c (pending layout RFC) + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + + // compute + for (wo.c: int32, 0, 8) { + for (ko.c: int32, 0, 2) { + + // init output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + + // convolution + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * + (float32*)kernel_pointer[(((((ko.c*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } // end rc.outer + } // end ko.c + } // end wo.c + + // cache write + for (wo: int32, 0, 8) { + for (ko: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((ho.outer*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[(((((wo*4096) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } + } +} +``` + +# Split on Height - "Full Output Slice" + +Adds a new parameter `h_split` which creates a loop split on the height `h` dimension. The cache reads and writes are moved to the outer of the two loops created by that split - the loop over `ho.outer`. This increases cache usage by a factor equivalent to `h_split`. The compute is still "full width" and "full depth" in the channel-out dimension and now over multiple slices in the height `h` dimension. + +The key changes in TIR versus the baseline are ... + +1) Increased cache allocations: + +``` + allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; +``` + +2) The loop split on the `h` dimension: + +``` + for (ho.outer: int32, 0, 4) { + for (ho.inner: int32, 0, 2) { +``` + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-1-64-64-64-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Kernel | 1x1 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 64 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | +| h_split | 2 | + +## Assumptions + +Same as baseline + +## To Do + +Same as baseline + +## Annotated TIR + +``` +primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), + kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 1, 1, 8, 32, 4], []), + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} + buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { + + // increased cache usage due to h_split parameter + allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // loop split ho.outer vs. ho.inner based on h_split parameter + for (ho.outer: int32, 0, 4) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + } + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + for (ko.c: int32, 0, 2) { + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * + (float32*)kernel_pointer[(((((ko.c*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } + } + } + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (ko: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } + } + } +} +``` + +# 3x3 conv2d (no padding) + +Change from a 1x1 kernel to a 3x3 kernel. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 kernel will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. + +The key changes in TIR versus the above are... + +1) Increased input cache size to hold the vertically adjacent slice + +``` + allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; +``` + +2) Loop over `ho.inner` upper bound increased from `h_split` = 2 to `h_split + 1` = 3 + +``` + for (ho.outer: int32, 0, 4) { + for (ho.inner: int32, 0, 3) { + if (((ho.outer*2) + ho.inner) < 8) { +``` + +The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-3-1-0-float32-2-1-64-64-64-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Kernel | 3x3 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 64 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | +| h_split | 2 | + +## Assumptions + +Same as above + +## To Do + +Same as above, and ... + +There may be some opportunity to optimize cache reuse in this case. Consider the loops over `ho.outer` and `ho.inner` and the index calculation `ho.outer * 64k + ho.inner * 32k` into the input pointer: + +| ho.outer | ho.inner | ho.outer * 64k + ho.inner * 32k | +| -------- | -------- | ------------------------------------- | +| 0 | 0 | 0 | +| 0 | 1 | 32k | +| 0 | 2 | 64k (vertical adjacent slice loop 0) | +| 1 | 0 | 64k | +| 1 | 1 | 96k | +| 1 | 2 | 128k (vertical adjacent slice loop 1) | +| 2 | 0 | 128k | +| 2 | 1 | 160k | +| 2 | 2 | 192k (vertical adjacent slice loop 2) | +| 3 | 0 | 192k | +| 3 | 1 | 224k | +| 3 | 2 | (No vertical adjacent slice loop 3) | + +Noe that the vertically adjacent slice in loop N (i.e. the loop where `ho.outer` = N) is reused in loop N + 1. + +## Annotated TIR + +``` +primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), + kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 3, 3, 8, 32, 4], []), + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} + buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { + + // increased input cache size to hold vertically adjacent slice + allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + for (ho.outer: int32, 0, 4) { + + // iterate over h_split + 1 = 3 input slices + for (ho.inner: int32, 0, 3) { + + // don't prefetch the vertically adjacent slice at the "bottom" of the input + if (((ho.outer*2) + ho.inner) < 8) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + } + } + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + for (ko.c: int32, 0, 2) { + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * + (float32*)kernel_pointer[(((((((ko.c*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } + } + } + } + } + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (ko: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } + } + } +} +``` \ No newline at end of file diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py index e0b7fb20ab8e..37a623b613f8 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py @@ -162,6 +162,7 @@ def conv2d_packed_filter( stride, padding, dtype, + h_split_factor, storage_scope="global", ): """ @@ -260,15 +261,45 @@ def compute(n, ho, wo, ko, hi, wi, ki): s[X_pad].compute_inline() s[X_packed].compute_inline() - # Perform scheduling - n, hid, wid, cid, hoff, woff, coff = s[Y].op.axis - slice = s[Y].fuse(wid, cid) + # cache read for the input / activation (X) Xl = s.cache_read(X_packed, storage_scope, [Y]) + + # cache write for the output (Y) Yl = s.cache_write(Y, storage_scope) - s[Yl].compute_at(s[Y], hid) - n, hid, slice, hoff, woff, coff = s[Yl].op.axis - s[Xl].compute_at(s[Yl], slice) + ######################## + # cache write schedule # + ######################## + + # loop schedule corresponding with nhwc8h8w32c layout + # using k to represent output channel + n, ho, wo, ko, hi, wi, ki = s[Y].op.axis + + # loop split h and compute cache write at outer loop split + # to increase cache usage by factor of h_split_factor + hoo, hoi = s[Y].split(ho, factor=h_split_factor) + s[Yl].compute_at(s[Y], hoo) + + #################### + # compute schedule # + #################### + + # loop schedule corresponding with nhwc8h8w32c layout + # using k to represent output channel + n, ho, wo, ko, hi, wi, ki = s[Yl].op.axis + + # reduction axes + # using rc to represent (reduction) input channel + rh, rw, rc = s[Yl].op.reduce_axis + + # split input channel by the block size + rco, rci = s[Yl].split(rc, factor=block_C) + + # loop split h and compute cache write at outer loop split + # to increase cache usage by factor of h_split_factor + hoo, hoi = s[Yl].split(ho, factor=h_split_factor) + s[Yl].reorder(n, hoo, hoi, wo, ko, rco, hi, wi, ki, rci) + s[Xl].compute_at(s[Yl], hoo) binds = {} if storage_scope and storage_scope != "global": @@ -287,6 +318,7 @@ def conv2d_packed_filter_nhwhwc( stride, padding, dtype, + h_split_factor, storage_scope="global", ): """ @@ -299,7 +331,7 @@ def conv2d_packed_filter_nhwhwc( assert kernel_size == tuple(shape_oihw8i32o4i[2:4]) block_shape = get_block_shape() - block_H, block_W, _ = block_shape + block_H, block_W, block_C = block_shape shape = get_packed_activation_layout(shape_nhwc, block_shape, packed_C=False) logical_output_shape = get_conv2d_nhwc_shape( shape_nhwc, @@ -372,18 +404,62 @@ def compute(n, ho, wo, hi, wi, k): s[X_pad].compute_inline() s[X_packed].compute_inline() + # cache read for the input / activation (X) + Xl = s.cache_read(X_packed, storage_scope, [Y]) + + # cache write for the output (Y) + Yl = s.cache_write(Y, storage_scope) + + ######################## + # cache write schedule # + ######################## + + # loop schedule corresponding with nhw8h8wc layout + # using k to represent output channel n, ho, wo, hi, wi, k = s[Y].op.axis - rh, rw, rc = s[Y].op.reduce_axis - rco, rci = s[Y].split(rc, factor=32) - s[Y].reorder(n, rco, wo, ho, k, hi, wi) - Xl = s.cache_read(X_packed, storage_scope, [Y]) - s[Xl].compute_at(s[Y], rco) + # split output channel by the block size + ko, ki = s[Y].split(k, factor=block_C) + + # loop split h and compute cache write at outer loop split + # to increase cache usage by factor of h_split_factor + hoo, hoi = s[Y].split(ho, factor=h_split_factor) + s[Y].reorder(n, hoo, hoi, wo, ko, hi, wi, ki) + s[Yl].compute_at(s[Y], hoo) + + #################### + # compute schedule # + #################### + + # loop schedule corresponding with nhw8h8wc layout + # using k to represent output channel + n, ho, wo, hi, wi, k = s[Yl].op.axis + + # reduction axes + # using rc to represent (reduction) input channel + rh, rw, rc = s[Yl].op.reduce_axis + + # split output & input channel by the block size + ko, ki = s[Yl].split(k, factor=block_C) + rco, rci = s[Yl].split(rc, factor=block_C) + + # loop split h and compute cache write at outer loop split + # to increase cache usage by factor of h_split_factor + hoo, hoi = s[Yl].split(ho, factor=h_split_factor) + s[Yl].reorder(n, hoo, hoi, wo, ko, rco, hi, wi, ki, rci) + s[Xl].compute_at(s[Yl], hoo) + + ####################### + # cache read schedule # + ####################### + + # loop schedule corresponding with nhw8h8wc layout + # using k to represent output channel + n, ho, wo, hi, wi, c = s[Xl].op.axis - ko, ki = s[Y].split(k, factor=32) - s[Y].reorder(n, rco, wo, ho, ko, hi, wi, ki) - Fl = s.cache_read(filt_packed, storage_scope, [Y]) - s[Fl].compute_at(s[Y], ko) + # split intput channel by the block size + co, ci = s[Xl].split(c, factor=block_C) + s[Xl].reorder(n, ho, wo, co, hi, wi, ci) binds = {} if storage_scope and storage_scope != "global": @@ -397,13 +473,14 @@ def compute(n, ho, wo, hi, wi, k): class BaseConv2d: batch = tvm.testing.parameter(1) - in_size = tvm.testing.parameter(8, 56) + in_size = tvm.testing.parameter(8, 56, 64) in_channel = tvm.testing.parameter(64) out_channel = tvm.testing.parameter(64) - kernel = tvm.testing.parameter(3) + kernel = tvm.testing.parameter(1, 3) stride = tvm.testing.parameter(1) - pad = tvm.testing.parameter(1) + pad = tvm.testing.parameter(0, 1) dtype = tvm.testing.parameter("float32") + h_split_factor = tvm.testing.parameter(1, 2) class TestConv2dLogical(BaseConv2d): @@ -445,6 +522,7 @@ def test_conv2d( pad, dtype, target, + h_split_factor, ): inputs = [ np.random.uniform(0, 255, size=shape_nhwc).astype(dtype), @@ -465,6 +543,7 @@ def test_conv2d( stride=(stride, stride), padding=(pad, pad, pad, pad), dtype=dtype, + h_split_factor=h_split_factor, ) return output, ref_output From 9cd07e4911cf0d5d4e323c646f3f40687e1935d6 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Wed, 13 Oct 2021 07:10:25 -0700 Subject: [PATCH 22/84] [Hexagon] Add hexagon launcher to apps and add to TVM's build system (#9220) * Add USE_HEXAGON_LAUNCHER cmake configuration to build the android hexagon launcher along with a standard TVM build. * Update hexagon launcher README.md to include instructions on how to build the launcher alongside TVM. * Move Hexagon launcher into top-level apps directory. * Refactor hexagon launcher cmake directory structure and group common code into cmake/HexagonLauncher.cmake. * Address CRs from @kparzysz-quic. --- .../hexagon_launcher}/README.md | 47 +++++- .../cmake/HexagonLauncher.cmake | 61 +++++++ .../cmake/android/CMakeLists.txt | 77 +++++++++ .../cmake/hexagon/CMakeLists.txt | 84 ++++++++++ .../hexagon_launcher}/launcher_android.cc | 0 .../hexagon_launcher}/launcher_core.cc | 0 .../hexagon_launcher}/launcher_core.h | 0 .../hexagon_launcher}/launcher_hexagon.cc | 0 .../hexagon_launcher}/launcher_main.cc | 0 .../hexagon_launcher}/launcher_rpc.idl | 0 .../hexagon_launcher}/launcher_util.cc | 0 .../hexagon_launcher}/launcher_util.h | 0 cmake/config.cmake | 3 + cmake/modules/Hexagon.cmake | 81 ++++++++- src/runtime/hexagon/launcher/CMakeLists.txt | 156 ------------------ 15 files changed, 339 insertions(+), 170 deletions(-) rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/README.md (76%) create mode 100644 apps/hexagon_launcher/cmake/HexagonLauncher.cmake create mode 100644 apps/hexagon_launcher/cmake/android/CMakeLists.txt create mode 100644 apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/launcher_android.cc (100%) rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/launcher_core.cc (100%) rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/launcher_core.h (100%) rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/launcher_hexagon.cc (100%) rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/launcher_main.cc (100%) rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/launcher_rpc.idl (100%) rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/launcher_util.cc (100%) rename {src/runtime/hexagon/launcher => apps/hexagon_launcher}/launcher_util.h (100%) delete mode 100644 src/runtime/hexagon/launcher/CMakeLists.txt diff --git a/src/runtime/hexagon/launcher/README.md b/apps/hexagon_launcher/README.md similarity index 76% rename from src/runtime/hexagon/launcher/README.md rename to apps/hexagon_launcher/README.md index a8a570918514..85e6897b74a3 100644 --- a/src/runtime/hexagon/launcher/README.md +++ b/apps/hexagon_launcher/README.md @@ -19,9 +19,7 @@ ## Compilation The launcher consists of two parts: part running on Hexagon, and part running -on Android. They need to be compiled separately. Since some source files are -shared between these two parts, make sure to delete all object files between -compilations. Compile the Hexagon code first. +on Android. Each component must be compiled separately. The supported Snapdragon architectures are 855, 865, and 888. @@ -33,7 +31,46 @@ The supported Snapdragon architectures are 855, 865, and 888. Android NDK can be downloaded from https://developer.android.com/ndk. Hexagon SDK is available at //developer.qualcomm.com/software/hexagon-dsp-sdk. -### Compilation of the Hexagon part +### Compilation with TVM + +Building the Hexagon launcher application as a component of the main TVM build +used for Hexagon codegen can be achieved by setting `USE_HEXAGON_LAUNCHER=ON`. +This option will compile core tvm, the android launcher binary and its corresponding +tvm_runtime, as well as the Hexagon launcher shared library and its corresponding +tvm_runtime. As described in the [Manual compilation](#Manual compilation) section +each component requires Hexagon and android dependencies. When building the launcher +along with TVM these configurations must be providing when invoking cmake. A minimal +example invocation for compiling TVM along with the Hexagon launcher is included below, + +``` +cmake -DCMAKE_MAKE_PROGRAM=make \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_CXX_FLAGS='-stdlib=libc++' \ + -DCMAKE_CXX_STANDARD=14 \ + -DUSE_LLVM=/path/to/hexagon/llvm/bin/llvm-config \ + -DUSE_HEXAGON_LAUNCHER=ON \ + -DUSE_ANDROID_TOOLCHAIN=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ + -DANDROID_PLATFORM=android-28 \ + -DANDROID_ABI=arm64-v8a \ + -DUSE_HEXAGON_ARCH=v68 \ + -DUSE_HEXAGON_SDK=/path/to/hexagon/SDK \ + -DUSE_HEXAGON_TOOLCHAIN=/path/to/hexagon/Toolchain/ .. +``` + +The Hexagon launcher application is an android binary and thus requires the use +of an android toolchain for compilation. Similarly, the Hexagon tvm runtime +requires the use of the Hexagon toolchain and depends on the Hexagon SDK. The +resulting hexagon launcher binaries can be found in the `launcher` subdirectory +of the cmake build directory. + +### Manual compilation + +Since some source files are shared between the Hexagon and android builds, +make sure to delete all object files between compilations. Compile the Hexagon +code first. + +#### Compilation of the Hexagon part 1. Build the static version of TVM runtime for Hexagon. Use Hexagon clang from the Hexagon SDK. This step is the same as building the shared version, @@ -55,7 +92,7 @@ Hexagon SDK is available at //developer.qualcomm.com/software/hexagon-dsp-sdk. 3. Run `make`. This will create `liblauncher_rpc_skel.so`. -### Compilation of the Android part +#### Compilation of the Android part 1. Build TVM runtime for Android, using clang for AArch64 from the Android NDK. Unlike in the Hexagon case, this should be the dynamic library (which diff --git a/apps/hexagon_launcher/cmake/HexagonLauncher.cmake b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake new file mode 100644 index 000000000000..4a7f803ce1ab --- /dev/null +++ b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(NOT "${FASTRPC_LIBS}" STREQUAL "SKEL" AND + NOT "${FASTRPC_LIBS}" STREQUAL "STUB") + message(SEND_ERROR "Please set FASTRPC_LIBS to either SKEL or STUB") +endif() + +if(NOT DEFINED USE_HEXAGON_SDK) + message(SEND_ERROR "Please set USE_HEXAGON_SDK to the location of Hexagon SDK") +endif() +if (NOT DEFINED USE_HEXAGON_ARCH) + message(SEND_ERROR "Please set USE_HEXAGON_ARCH to the Hexagon architecture version") +endif() + +set(TVM_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../") + +include(ExternalProject) +include("${TVM_SOURCE_DIR}/cmake/modules/HexagonSDK.cmake") + +find_hexagon_sdk_root("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}") + +include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_REMOTE_ROOT}) + +set(QAIC_EXE "${HEXAGON_QAIC_EXE}") +foreach(INCDIR IN LISTS HEXAGON_SDK_INCLUDES HEXAGON_REMOTE_ROOT) + list(APPEND QAIC_FLAGS "-I${INCDIR}") +endforeach() + +set(LAUNCHER_SRC "${CMAKE_CURRENT_SOURCE_DIR}/../../") +set(CMAKE_SKIP_RPATH TRUE) + +# Qaic for the domain header. +# +# Don't add paths to these filenames, or otherwise cmake may spontaneously +# add -o option to the qaic invocation (with an undesirable path). +set(LAUNCHER_RPC_IDL "launcher_rpc.idl") +set(LAUNCHER_RPC_H "launcher_rpc.h") +set(LAUNCHER_RPC_SKEL_C "launcher_rpc_skel.c") +set(LAUNCHER_RPC_STUB_C "launcher_rpc_stub.c") + +include_directories( + "${LAUNCHER_SRC}" + "${TVM_SOURCE_DIR}/include" + "${TVM_SOURCE_DIR}/3rdparty/dlpack/include" + "${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include" +) diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt new file mode 100644 index 000000000000..c000b0e97cad --- /dev/null +++ b/apps/hexagon_launcher/cmake/android/CMakeLists.txt @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.2) +project(HexagonAndroidLauncher C CXX) + +include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") + +add_custom_command( + OUTPUT ${LAUNCHER_RPC_STUB_C} + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} + "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" + COMMAND ${CMAKE_COMMAND} -E rename "${LAUNCHER_RPC_H}" + "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" + MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" +) + +include_directories(SYSTEM + "${HEXAGON_SDK_INCLUDES}" + "${HEXAGON_RPCMEM_ROOT}/inc" +) + +link_directories(${HEXAGON_REMOTE_ROOT}) + +add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) + +set(STUB_SRCS + "${LAUNCHER_SRC}/launcher_android.cc" + "${LAUNCHER_SRC}/launcher_core.cc" + "${LAUNCHER_SRC}/launcher_main.cc" + "${LAUNCHER_SRC}/launcher_util.cc" +) + +add_executable(launcher_android + "${STUB_SRCS}" + "${LAUNCHER_RPC_STUB_C}" +) + +ExternalProject_Add(android_tvm_runtime + SOURCE_DIR "${TVM_SOURCE_DIR}" + BUILD_COMMAND $(MAKE) runtime + CMAKE_ARGS + "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" + "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" + "-DANDROID_ABI=${ANDROID_ABI}" + "-DCMAKE_CXX_STANDARD=14" + "-DUSE_LIBBACKTRACE=OFF" + "-DUSE_LLVM=OFF" + "-DUSE_RPC=OFF" + INSTALL_COMMAND "" + BUILD_ALWAYS ON +) +ExternalProject_Get_Property(android_tvm_runtime BINARY_DIR) +ExternalProject_Add_Step(android_tvm_runtime copy_binaries + COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${CMAKE_INSTALL_PREFIX} + DEPENDEES install +) + +add_dependencies(launcher_android android_tvm_runtime) +add_library(tvm_runtime SHARED IMPORTED) +set_target_properties(tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.so") + +target_link_libraries(launcher_android cdsprpc log tvm_runtime) diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt new file mode 100644 index 000000000000..c76fcccc5a1a --- /dev/null +++ b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.2) +project(HexagonLauncherRPCSkel C CXX) + +include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") + +add_custom_command( + OUTPUT ${LAUNCHER_RPC_SKEL_C} ${LAUNCHER_RPC_H} + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} + "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" + MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" +) + +include_directories(SYSTEM ${HEXAGON_QURT_INCLUDES}) + +link_directories(${HEXAGON_QURT_LIBS}) + +add_definitions(-D_MACH_I32=int) +add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0) +add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) + +# Extra compile flags (both C and C++). +set(EXTRA_COMP_FLAGS + "-O3" + "-m${USE_HEXAGON_ARCH}" +) +string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") +set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") +set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") + +set(SKEL_SRCS + "${LAUNCHER_SRC}/launcher_core.cc" + "${LAUNCHER_SRC}/launcher_hexagon.cc" +) +add_library(launcher_rpc_skel SHARED + "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" + "${LAUNCHER_RPC_SKEL_C}" + "${SKEL_SRCS}" +) + +ExternalProject_Add(static_hexagon_tvm_runtime + SOURCE_DIR "${TVM_SOURCE_DIR}" + BUILD_COMMAND $(MAKE) runtime + CMAKE_ARGS + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" + "-DCMAKE_CXX_STANDARD=14" + "-DUSE_LIBBACKTRACE=OFF" + "-DUSE_LLVM=OFF" + "-DUSE_RPC=OFF" + "-DBUILD_STATIC_RUNTIME=ON" + "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" + INSTALL_COMMAND "" + BUILD_ALWAYS ON +) +ExternalProject_Get_Property(static_hexagon_tvm_runtime BINARY_DIR) +ExternalProject_Add_Step(static_hexagon_tvm_runtime copy_binaries + COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${CMAKE_INSTALL_PREFIX} + DEPENDEES install +) + +add_dependencies(launcher_rpc_skel static_hexagon_tvm_runtime) +add_library(static_tvm_runtime STATIC IMPORTED) +set_target_properties(static_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.a") + +target_link_libraries(launcher_rpc_skel -Wl,--whole-archive static_tvm_runtime -Wl,--no-whole-archive) + diff --git a/src/runtime/hexagon/launcher/launcher_android.cc b/apps/hexagon_launcher/launcher_android.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_android.cc rename to apps/hexagon_launcher/launcher_android.cc diff --git a/src/runtime/hexagon/launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_core.cc rename to apps/hexagon_launcher/launcher_core.cc diff --git a/src/runtime/hexagon/launcher/launcher_core.h b/apps/hexagon_launcher/launcher_core.h similarity index 100% rename from src/runtime/hexagon/launcher/launcher_core.h rename to apps/hexagon_launcher/launcher_core.h diff --git a/src/runtime/hexagon/launcher/launcher_hexagon.cc b/apps/hexagon_launcher/launcher_hexagon.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_hexagon.cc rename to apps/hexagon_launcher/launcher_hexagon.cc diff --git a/src/runtime/hexagon/launcher/launcher_main.cc b/apps/hexagon_launcher/launcher_main.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_main.cc rename to apps/hexagon_launcher/launcher_main.cc diff --git a/src/runtime/hexagon/launcher/launcher_rpc.idl b/apps/hexagon_launcher/launcher_rpc.idl similarity index 100% rename from src/runtime/hexagon/launcher/launcher_rpc.idl rename to apps/hexagon_launcher/launcher_rpc.idl diff --git a/src/runtime/hexagon/launcher/launcher_util.cc b/apps/hexagon_launcher/launcher_util.cc similarity index 100% rename from src/runtime/hexagon/launcher/launcher_util.cc rename to apps/hexagon_launcher/launcher_util.cc diff --git a/src/runtime/hexagon/launcher/launcher_util.h b/apps/hexagon_launcher/launcher_util.h similarity index 100% rename from src/runtime/hexagon/launcher/launcher_util.h rename to apps/hexagon_launcher/launcher_util.h diff --git a/cmake/config.cmake b/cmake/config.cmake index ade9d5c815c1..1fce11f90aed 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -279,6 +279,9 @@ set(USE_FALLBACK_STL_MAP OFF) set(USE_HEXAGON_DEVICE OFF) set(USE_HEXAGON_SDK /path/to/sdk) +# Whether to build the hexagon launcher +set(USE_HEXAGON_LAUNCHER OFF) + # Hexagon architecture to target when compiling TVM itself (not the target for # compiling _by_ TVM). This applies to components like the TVM runtime, but is # also used to select correct include/library paths from the Hexagon SDK when diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index eb3ad1f5ae4a..1491a4558611 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -53,23 +53,86 @@ if(BUILD_FOR_HEXAGON) include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_QURT_INCLUDES}) endif() -if(USE_HEXAGON_DEVICE STREQUAL "OFF") - list(APPEND COMPILER_SRCS src/target/opt/build_hexagon_off.cc) - return() -elseif(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}" AND - NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") - set(ERROR_MSG +if(USE_HEXAGON_LAUNCHER STREQUAL "ON") + set(USE_HEXAGON_DEVICE "${PICK_SIM}") +else() + if(USE_HEXAGON_DEVICE STREQUAL "OFF") + list(APPEND COMPILER_SRCS src/target/opt/build_hexagon_off.cc) + return() + elseif(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}" AND + NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") + set(ERROR_MSG "USE_HEXAGON_DEVICE must be one of [${PICK_NONE}|${PICK_SIM}|${PICK_HW}]") - message(SEND_ERROR "${ERROR_MSG}") - return() + message(SEND_ERROR "${ERROR_MSG}") + return() + endif() endif() -# If USE_HEXAGON_DEVICE is set to a valid value, make sure that USE_HEXAGON_SDK + +# If USE_HEXAGON_DEVICE/LAUNCHER is set to a valid value, make sure that USE_HEXAGON_SDK # is defined. if(NOT USE_HEXAGON_SDK) message(SEND_ERROR "Please set USE_HEXAGON_SDK to the Hexagon SDK root") return() endif() +if(USE_HEXAGON_LAUNCHER STREQUAL "ON") + + if(DEFINED USE_ANDROID_TOOLCHAIN) + if(NOT DEFINED ANDROID_PLATFORM) + message(SEND_ERROR "Please set ANDROID_PLATFORM " + "when providing an Android cmake toolchain.") + endif() + if(NOT DEFINED ANDROID_ABI) + message(SEND_ERROR "Please set ANDROID_ABI " + "when providing an Android cmake toolchain.") + endif() + else() + message(SEND_ERROR "Please set USE_ANDROID_TOOLCHAIN to build the android " + " launcher for hexagon.") + endif() + + set(LAUNCHER_BINARY_DIR "${CMAKE_BINARY_DIR}/launcher") + ExternalProject_Add(launcher_android + SOURCE_DIR "${CMAKE_SOURCE_DIR}/apps/hexagon_launcher/cmake/android" + INSTALL_DIR "${LAUNCHER_BINARY_DIR}" + BUILD_ALWAYS ON + CMAKE_ARGS + "-DCMAKE_TOOLCHAIN_FILE=${USE_ANDROID_TOOLCHAIN}" + "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" + "-DANDROID_ABI=${ANDROID_ABI}" + "-DFASTRPC_LIBS=STUB" + "-DUSE_HEXAGON_ARCH=v68" + "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" + "-DCMAKE_INSTALL_PREFIX:PATH=" + INSTALL_COMMAND "" + ) + ExternalProject_Get_Property(launcher_android BINARY_DIR) + ExternalProject_Add_Step(launcher_android copy_binaries + COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${LAUNCHER_BINARY_DIR} + DEPENDEES install + ) + ExternalProject_Add(launcher_hexagon + SOURCE_DIR "${CMAKE_SOURCE_DIR}/apps/hexagon_launcher/cmake/hexagon" + INSTALL_DIR "${LAUNCHER_BINARY_DIR}" + BUILD_ALWAYS ON + CMAKE_ARGS + "-DCMAKE_C_COMPILER=${USE_HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang" + "-DCMAKE_CXX_COMPILER=${USE_HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++" + "-DFASTRPC_LIBS=SKEL" + "-DUSE_HEXAGON_ARCH=v68" + "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" + "-DCMAKE_INSTALL_PREFIX:PATH=" + INSTALL_COMMAND "" + ) + ExternalProject_Get_Property(launcher_hexagon BINARY_DIR) + ExternalProject_Add_Step(launcher_hexagon copy_binaries + COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${LAUNCHER_BINARY_DIR} + DEPENDEES install + ) + + set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES "${LAUNCHER_BINARY_DIR}") +endif() + if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") find_hexagon_toolchain() message(STATUS "Hexagon toolchain: ${HEXAGON_TOOLCHAIN}") diff --git a/src/runtime/hexagon/launcher/CMakeLists.txt b/src/runtime/hexagon/launcher/CMakeLists.txt deleted file mode 100644 index d3a2f4f8161d..000000000000 --- a/src/runtime/hexagon/launcher/CMakeLists.txt +++ /dev/null @@ -1,156 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.2) -project(HexagonLauncher C CXX) - -if(NOT "${FASTRPC_LIBS}" STREQUAL "SKEL" AND - NOT "${FASTRPC_LIBS}" STREQUAL "STUB") - message(SEND_ERROR "Please set FASTRPC_LIBS to either SKEL or STUB") -endif() - -if(NOT DEFINED USE_HEXAGON_SDK) - message(SEND_ERROR "Please set USE_HEXAGON_SDK to the location of Hexagon SDK") -endif() -if (NOT DEFINED USE_HEXAGON_ARCH) - message(SEND_ERROR "Please set USE_HEXAGON_ARCH to the Hexagon architecture version") -endif() - -include(../../../../cmake/modules/HexagonSDK.cmake) - -find_hexagon_sdk_root("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}") - -include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_REMOTE_ROOT}) - -set(QAIC_EXE "${HEXAGON_QAIC_EXE}") -foreach(INCDIR IN LISTS HEXAGON_SDK_INCLUDES HEXAGON_REMOTE_ROOT) - list(APPEND QAIC_FLAGS "-I${INCDIR}") -endforeach() - -set(LAUNCHER_SRC "${CMAKE_CURRENT_SOURCE_DIR}") -set(CMAKE_SKIP_RPATH TRUE) - -# Qaic for the domain header. -# -# Don't add paths to these filenames, or otherwise cmake may spontaneously -# add -o option to the qaic invocation (with an undesirable path). -set(LAUNCHER_RPC_IDL "launcher_rpc.idl") -set(LAUNCHER_RPC_H "launcher_rpc.h") -set(LAUNCHER_RPC_SKEL_C "launcher_rpc_skel.c") -set(LAUNCHER_RPC_STUB_C "launcher_rpc_stub.c") - -add_custom_command( - OUTPUT ${LAUNCHER_RPC_SKEL_C} ${LAUNCHER_RPC_STUB_C} - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} - "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" - COMMAND ${CMAKE_COMMAND} -E rename "${LAUNCHER_RPC_H}" - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" - MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" -) - - -if("${FASTRPC_LIBS}" STREQUAL "SKEL") - # Skel libraries. - # - if (NOT DEFINED TVM_RUNTIME_HEXAGON) - message(SEND_ERROR "Please set TVM_RUNTIME_HEXAGON=/path/to/libtvm_runtime.a") - endif() - - include_directories(SYSTEM ${HEXAGON_QURT_INCLUDES}) - include_directories( - "${LAUNCHER_SRC}" - "${LAUNCHER_SRC}/../../../../include" - "${LAUNCHER_SRC}/../../../../3rdparty/dlpack/include" - "${LAUNCHER_SRC}/../../../../3rdparty/dmlc-core/include" - ) - link_directories(${HEXAGON_QURT_LIBS}) - - add_definitions(-D_MACH_I32=int) - add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0) - add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) - - # Extra compile flags (both C and C++). - set(EXTRA_COMP_FLAGS - "-O3" - "-m${USE_HEXAGON_ARCH}" - ) - string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") - set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") - set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") - - set(EXTRA_LINK_FLAGS - "-lposix" - "-lqurt" - "-Wl,--export-dynamic" - "-Wl,--whole-archive ${TVM_RUNTIME_HEXAGON} -Wl,--no-whole-archive" - "-Wl,--defsym=HEAP_SIZE=0x40000000" - ) - string(REGEX REPLACE ";" " " EXTRA_LINK_FLAGS_STR "${EXTRA_LINK_FLAGS}") - - set(SKEL_SRCS - "launcher_core.cc" - "launcher_hexagon.cc" - ) - add_library(launcher_rpc_skel SHARED - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" - "${LAUNCHER_RPC_SKEL_C}" - "${SKEL_SRCS}" - ) - - # Extra linker flags for linking shared libraries. - set_target_properties(launcher_rpc_skel PROPERTIES - LINK_FLAGS ${EXTRA_LINK_FLAGS_STR} - ) -else() - # Stub libraries. - # - if (NOT DEFINED TVM_RUNTIME_ANDROID) - message(SEND_ERROR "Please set TVM_RUNTIME_ANDROID=/path/to/libtvm_runtime.so") - endif() - - include_directories(SYSTEM - "${HEXAGON_SDK_INCLUDES}" - "${HEXAGON_RPCMEM_ROOT}/inc" - ) - include_directories( - "${LAUNCHER_SRC}" - "${LAUNCHER_SRC}/../../../../include" - "${LAUNCHER_SRC}/../../../../3rdparty/dlpack/include" - "${LAUNCHER_SRC}/../../../../3rdparty/dmlc-core/include" - ) - link_directories(${HEXAGON_REMOTE_ROOT}) - - add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) - - set(STUB_SRCS - "launcher_android.cc" - "launcher_core.cc" - "launcher_main.cc" - "launcher_util.cc" - ) - - add_executable(launcher_android - "${STUB_SRCS}" - "${LAUNCHER_RPC_STUB_C}" - ) - target_link_libraries(launcher_android cdsprpc log) - - set_target_properties(launcher_android PROPERTIES - LINK_FLAGS "${TVM_RUNTIME_ANDROID}" - ) -endif() From f9caf2eeac45f2ab985a902b6d3a9904b60df4dc Mon Sep 17 00:00:00 2001 From: Robert Kimball Date: Wed, 13 Oct 2021 08:55:22 -0700 Subject: [PATCH 23/84] Propagate tvm target through graph tuning setup (#9248) * Propagate tvm target through graph tuning setup * Don't append -device if it is already present in tvm_target * Make sure target string has device tracing only * revert accidental reformat * Update per review comments * fix lint issue * Use string for tvm target * Update python/tvm/autotvm/graph_tuner/utils/traverse_graph.py Co-authored-by: Cody Yu * Cleanup tests Co-authored-by: Cody Yu --- .../autotvm/graph_tuner/base_graph_tuner.py | 2 +- .../graph_tuner/utils/traverse_graph.py | 28 +++++++++++++++---- .../test_autotvm_graph_tuner_utils.py | 27 ++++++++++++++---- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index beb1aa03090d..25d56cf8cf02 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -166,7 +166,7 @@ def __init__( if isinstance(graph, relay.function.Function): node_dict = {} graph = bind_inputs(graph, input_shapes, dtype) - expr2graph(graph, self._target_ops, node_dict, self._node_list) + expr2graph(graph, self._target_ops, node_dict, self._node_list, target) else: raise RuntimeError("Unsupported graph type: %s" % str(type(graph))) diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index f61d34284e01..723e7fa77006 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -17,6 +17,7 @@ # pylint: disable=too-many-locals,too-many-statements,too-many-branches,protected-access """API for graph traversing.""" import threading +import re import tvm from tvm import relay, autotvm @@ -30,7 +31,7 @@ from .._base import OPT_OUT_OP -def expr2graph(expr, target_ops, node_dict, node_list): +def expr2graph(expr, target_ops, node_dict, node_list, tvm_target): """Convert relay expr to graph data structure and fetch workloads of target operators. @@ -50,6 +51,9 @@ def expr2graph(expr, target_ops, node_dict, node_list): Each node will be stored as a dictionary in the format of {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type], "name": str, "workloads": [tuple], "topi_op": [function]} + + tvm_target : tvm.target + The TVM target object. """ # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact # that # autotvm tasks == # ops. But this won't be true after having relay op @@ -58,12 +62,12 @@ def expr2graph(expr, target_ops, node_dict, node_list): env.reset(target_ops) # pylint: disable=not-context-manager with env: - _expr2graph_impl(expr, target_ops, node_dict, node_list) + _expr2graph_impl(expr, target_ops, node_dict, node_list, tvm_target) task_pos = 0 for node_entry in node_list: if node_entry["op"] in target_ops: task_name, args = env.task_collection[task_pos] - task = autotvm.task.create(task_name, args, target="llvm") + task = autotvm.task.create(task_name, args, target=tvm_target) node_entry["workloads"] = [task.workload] node_entry["topi_op"] = [task_name] task_pos += 1 @@ -77,7 +81,18 @@ def _infer_type(node): return entry if isinstance(node, relay.Function) else entry.body -def _expr2graph_impl(expr, target_ops, node_dict, node_list): +def _replace_device_with_tracing(target): + """This is to replace -device=XXX with -device=tracing in the tvm_target string. + It is a stand-along function for testability. + We need to have device=tracing in order to fetch the workloads, it is not used + for anything beyond that so it is safe to override the device here only.""" + target = str(target) + if "-device" in target: + return re.sub("-device=[^\\-$]+", "-device=tracing ", target).strip(" ") + return target + " -device=tracing" + + +def _expr2graph_impl(expr, target_ops, node_dict, node_list, tvm_target): """Implementation to convert relay expr to graph data structure""" def _traverse_expr(node): @@ -128,8 +143,9 @@ def _traverse_expr(node): call = relay.Call(node.op, params, node.attrs) mod = tvm.IRModule.from_expr(relay.Function(params, call)) relay.backend.compile_engine.get().clear() + tracing_target = _replace_device_with_tracing(tvm_target) build_thread = threading.Thread( - target=relay.build, args=(mod, "llvm -device=tracing", None, None) + target=relay.build, args=(mod, tracing_target, None, None) ) build_thread.start() build_thread.join() @@ -139,7 +155,7 @@ def _traverse_expr(node): elif isinstance(node, Function): # Ignore root node since it equals to input function expression if node != expr: - _expr2graph_impl(node, target_ops, node_dict, node_list) + _expr2graph_impl(node, target_ops, node_dict, node_list, tvm_target) return elif isinstance(node, TupleGetItem): in_node_idx = node_dict[node.tuple_value] diff --git a/tests/python/unittest/test_autotvm_graph_tuner_utils.py b/tests/python/unittest/test_autotvm_graph_tuner_utils.py index 3f6d3980ee28..583bd366847c 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_utils.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_utils.py @@ -20,6 +20,8 @@ # helps avoid topi arithmetic operator overloading issue: # https://github.com/apache/tvm/issues/3240 # TODO: restore the file name after this issue is resolved. +import pytest + import tvm from tvm import te @@ -34,6 +36,7 @@ bind_inputs, ) from tvm.autotvm.graph_tuner._base import OPT_OUT_OP +from tvm.autotvm.graph_tuner.utils.traverse_graph import _replace_device_with_tracing from tvm.relay.expr import Call, TupleGetItem, Tuple, Var @@ -57,7 +60,7 @@ def test_has_multiple_inputs(): target_ops = [relay.op.get("nn.conv2d")] node_list = [] node_dict = {} - expr2graph(net, target_ops, node_dict, node_list) + expr2graph(net, target_ops, node_dict, node_list, tvm.target.Target("llvm")) input_names = ["data"] verify_has_multiple_inputs(node_list, 2, input_names, False) verify_has_multiple_inputs(node_list, 4, input_names, False) @@ -79,7 +82,7 @@ def _count_node(node): relay.analysis.post_order_visit(mod["main"], _count_node) - expr2graph(mod["main"], target_ops, node_dict, node_list) + expr2graph(mod["main"], target_ops, node_dict, node_list, tvm.target.Target("llvm")) assert len(node_list) == len(op_name_list) for i, item in enumerate(zip(op_name_list, node_list)): op_name, node = item @@ -103,7 +106,7 @@ def test_get_direct_ancestor(): target_ops = [relay.op.get("nn.conv2d")] node_list = [] node_dict = {} - expr2graph(net, target_ops, node_dict, node_list) + expr2graph(net, target_ops, node_dict, node_list, tvm.target.Target("llvm")) visited_dict = {} input_names = ["data"] out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names) @@ -115,7 +118,7 @@ def test_get_direct_ancestor(): net = bind_inputs(net, {"data": (1, 16, 224, 224)}) node_list = [] node_dict = {} - expr2graph(net, target_ops, node_dict, node_list) + expr2graph(net, target_ops, node_dict, node_list, tvm.target.Target("llvm")) out = get_direct_ancestor(node_list, visited_dict, target_ops, 3, input_names) assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out) @@ -134,7 +137,7 @@ def test_get_in_nodes(): input_names = ["data"] node_list = [] node_dict = {} - expr2graph(net, target_ops, node_dict, node_list) + expr2graph(net, target_ops, node_dict, node_list, tvm.target.Target("llvm")) out = get_in_nodes(node_list, target_ops, input_names) expected_out = {3: [0], 4: [3, 0], 7: [4]} diff_set = set(out) ^ set(expected_out) @@ -155,6 +158,20 @@ def test_get_out_nodes(): ) +def test_target_device_replacement(): + assert _replace_device_with_tracing("cuda") == "cuda -device=tracing" + assert ( + _replace_device_with_tracing("cuda -device=some_device -libs=cudnn") + == "cuda -device=tracing -libs=cudnn" + ) + assert ( + _replace_device_with_tracing("llvm -device=arm_cpu -arg=xxx") + == "llvm -device=tracing -arg=xxx" + ) + assert _replace_device_with_tracing("llvm -device=arm_cpu") == "llvm -device=tracing" + assert _replace_device_with_tracing("llvm -device=abc, def") == "llvm -device=tracing" + + if __name__ == "__main__": test_has_multiple_inputs() test_expr2graph() From 185e2fb02eab3d812424c460af15c536509f9090 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Wed, 13 Oct 2021 09:10:30 -0700 Subject: [PATCH 24/84] [Topi] Fix direct SIMD conv2d schedule name (#9225) * feature tested * change name --- python/tvm/relay/op/strategy/arm_cpu.py | 6 +++--- python/tvm/topi/arm_cpu/conv2d.py | 16 ++++++++-------- .../topi/arm_cpu/cortex_m7/conv2d/direct_simd.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index e8731a0d6954..06dfc87038fe 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -130,9 +130,9 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): elif layout == "NHWC": if "SMLAD" in isa and kernel_layout == "HWOI": strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.conv2d_direct_simd), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd), - name="conv2d_direct_simd.micro_dev", + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_direct_simd), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_direct_simd), + name="conv2d_nhwc_direct_simd.micro_dev", ) elif kernel_layout == "HWIO": is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm() diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index b3af36740551..0500eb55996c 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -505,15 +505,15 @@ def _callback(op): return s -@autotvm.register_topi_compute("conv2d_direct_simd.arm_cpu") -def conv2d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype): - """Compute conv2d with SIMD (v7e-m).""" - return direct_simd.conv2d_direct_simd_compute( +@autotvm.register_topi_compute("conv2d_nhwc_direct_simd.arm_cpu") +def conv2d_nhwc_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d_nhwc with SIMD (v7e-m).""" + return direct_simd.conv2d_nhwc_direct_simd_compute( cfg, data, kernel, strides, padding, dilation, out_dtype ) -@autotvm.register_topi_schedule("conv2d_direct_simd.arm_cpu") -def schedule_conv2d_direct_simd(cfg, outs): - """Create schedule for conv2d_direct_simd""" - return direct_simd.conv2d_direct_simd_nhwc_schedule(cfg, outs) +@autotvm.register_topi_schedule("conv2d_nhwc_direct_simd.arm_cpu") +def schedule_conv2d_nhwc_direct_simd(cfg, outs): + """Create schedule for conv2d_nhwc_direct_simd""" + return direct_simd.conv2d_nhwc_direct_simd_schedule(cfg, outs) diff --git a/python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py b/python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py index 307312076a7e..5ef9fd813eb2 100644 --- a/python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py +++ b/python/tvm/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py @@ -30,7 +30,7 @@ ) -def conv2d_direct_simd(*args, **kwargs): +def conv2d_nhwc_direct_simd(*args, **kwargs): """Defines the Cortex-M7 SIMD implementation of conv2d.""" assert not kwargs, "Do not support kwargs in template function call" args = deserialize_args(args) @@ -39,17 +39,17 @@ def conv2d_direct_simd(*args, **kwargs): cfg = autotvm.get_config() args = [cfg] + args assert layout == "NHWC" - conv = conv2d_direct_simd_compute(*args) - sched = conv2d_direct_simd_nhwc_schedule(cfg, [data, kernel, conv]) + conv = conv2d_nhwc_direct_simd_compute(*args) + sched = conv2d_nhwc_direct_simd_schedule(cfg, [data, kernel, conv]) return sched, [data, kernel, conv] -conv2d_direct_simd.template_key = "direct_simd" -conv2d_direct_simd.default_data_layout = "NHWC" -conv2d_direct_simd.default_kernel_layout = "HWOI" +conv2d_nhwc_direct_simd.template_key = "direct_simd" +conv2d_nhwc_direct_simd.default_data_layout = "NHWC" +conv2d_nhwc_direct_simd.default_kernel_layout = "HWOI" -def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): +def conv2d_nhwc_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute function for Cortex-M7 SIMD implementation of conv2d.""" assert isinstance(strides, int) or len(strides) == 2 assert isinstance(dilation, int) or len(dilation) == 2 @@ -146,7 +146,7 @@ def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, ou return conv -def conv2d_direct_simd_nhwc_schedule(cfg, outs): +def conv2d_nhwc_direct_simd_schedule(cfg, outs): """Schedule function for Cortex-M7 SIMD implementation of conv2d.""" sched = te.create_schedule([x.op for x in outs]) From 2dc58bea779f02f689e9c9961741ac18af1f5ebf Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Wed, 13 Oct 2021 18:01:54 +0100 Subject: [PATCH 25/84] Bumping up CMSIS-NN version to be in sync with TFLu (#9247) Change-Id: I51103632f6d41652d616857f987a846ea2b22a5c --- docker/install/ubuntu_install_ethosu_driver_stack.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/install/ubuntu_install_ethosu_driver_stack.sh b/docker/install/ubuntu_install_ethosu_driver_stack.sh index 35b2b4c74b7b..c5ab20e18277 100755 --- a/docker/install/ubuntu_install_ethosu_driver_stack.sh +++ b/docker/install/ubuntu_install_ethosu_driver_stack.sh @@ -24,7 +24,7 @@ fvp_dir="/opt/arm/FVP_Corstone_SSE-300_Ethos-U55" cmake_dir="/opt/arm/cmake" ethosu_dir="/opt/arm/ethosu" ethosu_driver_ver="21.05" -cmsis_ver="5.7.0" +cmsis_ver="5.8.0" mkdir -p /opt/arm From b2065700bf0dedf3242a3aa0dc5da914b83e7a57 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Wed, 13 Oct 2021 17:35:51 -0700 Subject: [PATCH 26/84] [Runtime] Pipeline Executor Second patch, configuration load and executor export/import. (#9108) * [pipeline executor] Add configuration load function and pipeline executor export,import function. * address review comments. * polish comments and doc string. * address review comments. * address review comments. * Change mod_idx start from 0, remove mod_idx - 1 logic. * address review comments. * polish documents. * adress review comments * address review comments. * address review comments. * polish the document. * address review comments. * address review comments. * polish comments. * Triger build. * address review comments. * address review comments. * fix grammar issue. * polish documents. * add single global binding check. * address review comments. * trigger build. --- python/tvm/contrib/pipeline_executor.py | 195 ++++++++++++++----- src/runtime/pipeline/pipeline_executor.cc | 124 ++++++++++-- src/runtime/pipeline/pipeline_executor.h | 100 +++++++++- src/runtime/pipeline/pipeline_scheduler.cc | 37 ++++ src/runtime/pipeline/pipeline_scheduler.h | 52 +++++ src/runtime/pipeline/pipeline_struct.h | 185 ++++++++++++++++++ tests/python/relay/test_pipeline_executor.py | 41 ++-- 7 files changed, 660 insertions(+), 74 deletions(-) create mode 100644 src/runtime/pipeline/pipeline_scheduler.cc create mode 100644 src/runtime/pipeline/pipeline_scheduler.h create mode 100644 src/runtime/pipeline/pipeline_struct.h diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index 36c03891d210..37b9fed8eb91 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -16,6 +16,7 @@ # under the License. """Pipeline executor that executes a series of modules in a pipeline fashion.""" import json +import os import tvm._ffi from tvm import relay from tvm.relay.transform import InferType @@ -47,13 +48,13 @@ def build(pipe_configs): ret: PipelineExecutorFactoryModule Common interface for pipeline executor factory modules. """ - mods = {} + libs = {} mod_n_configs = pipe_configs.get_config() config_len = len(mod_n_configs) string_config = [{} for _ in range(config_len)] for ir_mod, mod_config in mod_n_configs.items(): mconf = mod_config["pipeline"].copy() - mod_idx = mconf["mod_idx"] - 1 + mod_idx = mconf["mod_idx"] dev = mod_config["dev"] target = mod_config["target"] build_func = relay.build @@ -61,7 +62,7 @@ def build(pipe_configs): if "build" in mod_config and mod_config["build"]: build_func = mod_config["build"] - mod = build_func( + lib = build_func( ir_mod, target, params=mod_config["params"], @@ -72,9 +73,9 @@ def build(pipe_configs): mconf["dev"] = "{},{}".format(dev.device_type, dev.device_id) # Create a pipeline configuration. string_config[mod_idx] = mconf - mods[mod] = {"dev": dev} + libs[mod_idx] = {"lib": lib, "dev": dev} - return PipelineExecutorFactoryModule(mods, string_config) + return PipelineExecutorFactoryModule(libs, string_config) class PipelineModule(object): @@ -82,12 +83,59 @@ class PipelineModule(object): Parameters ---------- - module : PipelineExecutorFactoryModule - Common interface for pipeline executor factory modules. + module : Union[PipelineExecutorFactoryModule, Module] + Common interface for pipeline executor factory modules or Module. """ def __init__(self, module): - self.module = module.module + if isinstance(module, PipelineExecutorFactoryModule): + self.module = module.module + else: + self.module = module + # Get the packed functions from the pipeline executor. + self._get_num_outputs = self.module["get_num_outputs"] + + @property + def num_outputs(self): + """Get the number of outputs. + Returns + ------- + count : int + The number of outputs. + """ + return self._get_num_outputs() + + @staticmethod + def load_library(config_file_name): + """Import files to create a pipeline executor. + + Parameters + ---------- + config_file_name : str + Path and name of the configuration file, the configuration file contains the + disk path of the parameter file, library file, and JSON file. + """ + with open(config_file_name, "r") as file_handle: + config = file_handle.read() + config = json.loads(config) + if "load_config" not in config or "pipeline_config" not in config: + raise RuntimeError( + '"load_config" or "pipeline_config" is missing in %s' % config_file_name + ) + + # The config file used to load library, prameters, and JSON files. + with open(config["load_config"], "r") as file_handle: + load_config = file_handle.read() + + # The config file used to load pipeline compute config. + with open(config["pipeline_config"], "r") as file_handle: + pipeline_config = file_handle.read() + + # Load a PipelineExecutor from the disk files. + load_library = tvm._ffi.get_global_func("tvm.pipeline_executor.load", allow_missing=False) + module = load_library(load_config, pipeline_config) + + return PipelineModule(module) class PipelineConfig(object): @@ -139,13 +187,14 @@ def get_owner_idx(self): if isinstance(self.io_owner, PipelineConfig.ModuleWrapper): return self.io_owner.idx - return 0 + return -1 - def is_global_interface(self): - """The global interface is the interface visible to the caller which use a pipeline - executor, the global input interface is responsible for passing parameters to the - internal module interface, and the global output interface is responsible for - outputting the results computed by the pipeline executor to a caller. + def is_pipeline_executor_interface(self): + """The pipeline interface is used to interact with the caller. There are two types + of interfaces, one is 'input' another is 'output'. The pipeline input interface + is responsible for passing parameters to the internal module interface, and the + pipeline output interface is responsible for outputting the results computed by + the pipeline executor to the caller. """ return not isinstance(self.io_owner, PipelineConfig.ModuleWrapper) @@ -182,9 +231,9 @@ def check_dag_acyclic(self, start, inputs): def connect(self, binding): """Connect the current interface to the destination interface. - Correct connections are as follows: 1. global input connected to module input, - 2. module output connected to global output, 3. module output connected to - module input. + Correct connections are as follows: 1. the pipeline input connected to a module input, + 2. the module output connected to a pipeline output, 3. the module output connected to + a module input. Parameters ---------- @@ -196,31 +245,31 @@ def connect(self, binding): if self.io_owner == binding.io_owner: raise RuntimeError(f"Can not bind itself.") - if not self.is_global_interface() and self.io_type == "input": + if not self.is_pipeline_executor_interface() and self.io_type == "input": raise RuntimeError(f"Module can only bind from output interface!") if ( - not self.is_global_interface() - and not binding.is_global_interface() + not self.is_pipeline_executor_interface() + and not binding.is_pipeline_executor_interface() and binding.io_type == "output" ): raise RuntimeError(f"Can not bind module output with another module output!") if ( - not self.is_global_interface() - and binding.is_global_interface() + not self.is_pipeline_executor_interface() + and binding.is_pipeline_executor_interface() and binding.io_type == "input" ): - raise RuntimeError(f"Can not bind module output with global input!") + raise RuntimeError(f"Can not bind module output with pipeline input!") - if self.is_global_interface() and self.io_type == "output": + if self.is_pipeline_executor_interface() and self.io_type == "output": raise RuntimeError(f"Global output can not be used as binding start point.") - if self.is_global_interface() and binding.io_type != "input": + if self.is_pipeline_executor_interface() and binding.io_type != "input": raise RuntimeError(f"Global input can only bind with module input.") self.bindings.append(binding) - if not self.is_global_interface(): + if not self.is_pipeline_executor_interface(): # Check whether the data types of the source and destination are the same. if ( isinstance(binding.io_owner, PipelineConfig.ModuleWrapper) @@ -431,13 +480,16 @@ def get_config(self): for dep in binding.bindings: dep_item = {} _, dname = dep.get_name() - dep_item["mod_idx"] = dep.get_owner_idx() - dep_item["input_name"] = dname + if dep.is_pipeline_executor_interface(): + dep_item["global_output_index"] = int(dname) + else: + dep_item["mod_idx"] = dep.get_owner_idx() + dep_item["input_name"] = dname dep_conf.append(dep_item) # The value of ouput_idx start from 0. output["output_idx"] = int(binding.name) - output["dependent"] = dep_conf + output["dependencies"] = dep_conf output_conf.append(output) mconf["mod_idx"] = module.idx @@ -472,7 +524,7 @@ def dag_topology_sort(self): mlist += temp_list for mod, i in zip(mlist, range(len(mlist))): - self.mod_wrapper[mod].set_idx_name(i + 1) + self.mod_wrapper[mod].set_idx_name(i) def get_mod_idx(self, mod): # Return the module index. @@ -502,16 +554,13 @@ class PipelineExecutorFactoryModule(object): """ def __init__(self, pipeline_mods, mods_config): - mods, config = self.graph_executor_create(pipeline_mods, mods_config) - assert ( - pipeline_executor_enabled() - ), "Pipeline executor is not enabled. Please \ - re-build TVM with USE_PIPELINE_EXECUTOR=ON" - pipeline_create = tvm._ffi.get_global_func( + self.pipeline_mods = pipeline_mods + self.mods_config = mods_config + graph_executors, config = self.graph_executor_create(pipeline_mods, mods_config) + self.pipeline_create = tvm._ffi.get_global_func( "tvm.pipeline_executor.create", allow_missing=False ) - assert pipeline_create - self.module = pipeline_create(mods, config) + self.module = self.pipeline_create(graph_executors, config) def graph_executor_create(self, pipeline_mods, mod_config): """Create graph_executor list and return configuration as a json string. @@ -532,12 +581,70 @@ def graph_executor_create(self, pipeline_mods, mod_config): mod_config : str The Modudle configuration. """ + # Should store modules in the list named 'mods' in index order. + mods = [None for _ in range(len(pipeline_mods))] + for lib_index in pipeline_mods: + pipeline_lib = pipeline_mods[lib_index]["lib"] + dev = pipeline_mods[lib_index]["dev"] + lib = graph_executor.GraphModule(pipeline_lib["default"](dev)) + # Return a module list sorted by lib_index. + mods[lib_index] = lib.module + + return mods, json.dumps(mod_config) + + def export_library(self, directory_path): + """Export the pipeline executor into disk files. - mods = [] - for pipeline_mod in pipeline_mods: - mod = graph_executor.GraphModule( - pipeline_mod["default"](pipeline_mods[pipeline_mod]["dev"]) + Parameters + ---------- + directory_path : str + Export the files to this directory. + """ + if not self.pipeline_mods: + raise RuntimeError(f"The pipeline executor has not been initialized.") + + # Check if the directory_path exists. + if not os.path.exists(directory_path): + raise RuntimeError(f"The directory {directory_path} does not exist.") + # Create an load configuration. + load_config_file_name = "{}/load_config".format(directory_path) + pipeline_config_file_name = "{}/pipeline_config".format(directory_path) + config = {} + config["load_config"] = load_config_file_name + config["pipeline_config"] = pipeline_config_file_name + load_config = [] + # Export the library, JSON, and parameter into files, then export these files path + # into a configuration file. + for lib_index in self.pipeline_mods: + mconfig = {} + mconfig["mod_idx"] = lib_index + mconfig["lib_name"] = "{}/lib{}.so".format(directory_path, lib_index) + mconfig["json_name"] = "{}/json{}".format(directory_path, lib_index) + mconfig["params_name"] = "{}/params{}".format(directory_path, lib_index) + mconfig["dev"] = "{},{}".format( + self.pipeline_mods[lib_index]["dev"].device_type, + self.pipeline_mods[lib_index]["dev"].device_id, ) - mods.append(mod.module) - return mods, json.dumps(mod_config) + # Get the graph, lib, and parameters from GraphExecutorFactoryModule. + graph, lib, params = self.pipeline_mods[lib_index]["lib"] + # Export the lib, graph, and parameters to disk. + lib.export_library(mconfig["lib_name"]) + with open(mconfig["json_name"], "w") as file_handle: + file_handle.write(graph) + with open(mconfig["params_name"], "wb") as file_handle: + file_handle.write(relay.save_param_dict(params)) + + load_config.append(mconfig) + + with open(load_config_file_name, "w") as file_handle: + json.dump(load_config, file_handle) + + with open(pipeline_config_file_name, "w") as file_handle: + json.dump(self.mods_config, file_handle) + + config_file_name = "{}/config".format(directory_path) + with open(config_file_name, "w") as file_handle: + json.dump(config, file_handle) + + return config_file_name diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index 41f867057282..3820ce942af0 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -21,31 +21,129 @@ * \file pipeline_executor.cc */ #include "pipeline_executor.h" - namespace tvm { namespace runtime { +/*! + * \brief Give frontends an access to packed functions. + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding packed function. + */ +PackedFunc PipelineExecutor::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "get_num_outputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc(); + } + return nullptr; +} -void PipelineRuntime::Init(const Array& modules, - const std::string& pipeline_json) { - return; +/*! + * \brief Use the mod_config information to create a graph runtime list. + * \param mod_config The config information that generates by the export library function call. + */ +std::vector PipelineExecutor::CreateGraphModules(const ModuleConfig& mod_config) { + const PackedFunc* graph_executor_create = Registry::Get("tvm.graph_executor.create"); + std::vector ret; + ret.resize(mod_config.size()); + for (auto config : mod_config) { + // Load library. + auto lib = Module::LoadFromFile(config.second.lib_name.c_str()); + + // Read json. + std::ifstream ifJson(config.second.json_name.c_str()); + if (ifJson.fail()) { + LOG(FATAL) << "json file not found: " << config.second.json_name; + } + const std::string json((std::istreambuf_iterator(ifJson)), + std::istreambuf_iterator()); + + // Create a graph executor. + std::istringstream istr(config.second.dev); + std::string str; + int device_type = 1, device_id = 0; + while (getline(istr, str, ';')) { + std::istringstream istr_dev(str); + std::string str_temp; + if (getline(istr_dev, str_temp)) { + device_type = stoi(str_temp); + } + if (getline(istr_dev, str_temp)) { + device_id = stoi(str_temp); + } + } + Module graph_module = (*graph_executor_create)(json, lib, device_type, device_id); + + // Load parameters. + TVMByteArray params_arr; + const char* params_file_name = config.second.params_name.c_str(); + std::ifstream if_param(params_file_name); + if (if_param.fail()) { + LOG(FATAL) << "params file not found: " << params_file_name; + } + const std::string params((std::istreambuf_iterator(if_param)), + std::istreambuf_iterator()); + params_arr.data = params.c_str(); + params_arr.size = params.length(); + auto load_params = graph_module.GetFunction("load_params"); + load_params(params_arr); + + // Put a graph executor module into the vector. + ret[config.first] = graph_module; + } + return ret; } -/* GetFunction can not be pure abstract function, implement an empty function for now. +/*! + * \brief Initialize the pipeline executor with a list of modules to be pipelined + * and config in JSON format. + * \param modules The module list used for building the pipeline. + * \param pipeline_json The configuration of modules dependencies. */ -PackedFunc PipelineRuntime::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { - return nullptr; +void PipelineExecutor::Init(const std::vector& modules, const std::string& pipeline_json) { + ICHECK(!modules.empty()) << "The graph executor module list is empty."; + // Use JSONReader to load pipeline configuration. + std::istringstream is(pipeline_json); + dmlc::JSONReader reader(&is); + PipelineConfig& pipeline_config = this->LoadPipelineConfig(&reader); + ICHECK(!pipeline_config.Empty()) << "The pipeline config information is empty."; + // Initialize the pipeline function class used for pipeline thread pool management + // and schedule etc. This function returns the number of output. + num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config); + return; } -Module PipelineRuntimeCreate(const Array& m, - const std::string& pipeline_json) { - auto exec = make_object(); - exec->Init(m, pipeline_json); +Module PipelineExecutorCreate(const Array& m, const std::string& pipeline_json) { + ICHECK(!m.empty()) << "The module list is empty."; + auto exec = make_object(); + std::vector graph_modules; + for (auto mod : m) { + graph_modules.push_back(mod); + } + exec->Init(graph_modules, pipeline_json); + return Module(exec); +} + +Module PipelineExecutorLoad(const std::string& load_json, const std::string& pipeline_json) { + auto exec = make_object(); + std::istringstream is(load_json); + dmlc::JSONReader reader(&is); + ModuleConfig& mod_config = exec->LoadModuleConfig(&reader); + ICHECK(!mod_config.empty()) << "The module config is empty."; + std::vector modules = exec->CreateGraphModules(mod_config); + exec->Init(modules, pipeline_json); return Module(exec); } TVM_REGISTER_GLOBAL("tvm.pipeline_executor.create").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = PipelineRuntimeCreate(args[0], args[1]); + *rv = PipelineExecutorCreate(args[0], args[1]); +}); + +TVM_REGISTER_GLOBAL("tvm.pipeline_executor.load").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = PipelineExecutorLoad(args[0], args[1]); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index c7625c62b724..a883ba25ec08 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -23,9 +23,16 @@ */ #ifndef TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_ #define TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_ + #include +#include +#include +#include #include +#include + +#include "pipeline_scheduler.h" namespace tvm { namespace runtime { /*! @@ -36,18 +43,23 @@ namespace runtime { * * This executor can be accessed by various language via TVM runtime PackedFunc API. */ -class TVM_DLL PipelineRuntime : public ModuleNode { +class TVM_DLL PipelineExecutor : public ModuleNode { public: /*! * \Return the type key of the executor. */ - const char* type_key() const final { return "PipelineRuntime"; } + const char* type_key() const final { return "PipelineExecutor"; } /*! - * \brief Initialize the pipeline executor with module array and json text. + * \brief Initialize the pipeline executor with module array and JSON text. * \param modules The module list used for building pipeline. * \param pipeline_json The configuration of modules dependencies. */ - void Init(const Array& modules, const std::string& pipeline_json); + void Init(const std::vector& modules, const std::string& pipeline_json); + /*! + * \brief Use the information of mod_config to create a list of graph executor. + * \param mod_config The configuration information generated by the library export function call. + */ + std::vector CreateGraphModules(const ModuleConfig& mod_config); /*! * \brief Give frontends an access to packed functions. * \param name The name of the function. @@ -55,6 +67,86 @@ class TVM_DLL PipelineRuntime : public ModuleNode { * \return The corresponding packed function. */ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + + /*! + * \brief Get the number of outputs. + * + * \return The number of outputs. + */ + int NumOutputs() const { return num_outputs_; } + + /*!\brief Load the module files information.*/ + ModuleConfig& LoadModuleConfig(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + int mod_idx = -1; + std::string lib_name; + std::string json_name; + std::string params_name; + std::string dev; + while (reader->NextObjectItem(&key)) { + if (key == "mod_idx") { + reader->Read(&mod_idx); + } else if (key == "lib_name") { + reader->Read(&lib_name); + } else if (key == "json_name") { + reader->Read(&json_name); + } else if (key == "params_name") { + reader->Read(¶ms_name); + } else if (key == "dev") { + reader->Read(&dev); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx; + // Load the lib, json, and params information. + ICHECK(!lib_name.empty()) << "lib_name is empty."; + ICHECK(!json_name.empty()) << "json_name is empty."; + ICHECK(!params_name.empty()) << "params_name is empty."; + mod_config_[mod_idx] = GraphModuleLoadInfo(lib_name, json_name, params_name, dev); + } + return mod_config_; + } + + private: + /*!\brief The class used to execute and schedule the pipeline logic.*/ + PipelineScheduler pipeline_scheduler_; + /*!\brief The dependency information of each graph runtime module of the pipeline.*/ + PipelineConfig pipeline_config_; + /*!\brief The module information used to create the graph runtimes.*/ + ModuleConfig mod_config_; + /*!\brief How many outputs are in this pipeline executor.*/ + size_t num_outputs_ = 0; + /*!\brief Json loader.*/ + PipelineConfig& LoadPipelineConfig(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + int mod_idx = -1; + OutputMap output; + std::string dev; + while (reader->NextObjectItem(&key)) { + if (key == "mod_idx") { + reader->Read(&mod_idx); + } else if (key == "dev") { + reader->Read(&dev); + } else if (key == "output") { + reader->Read(&output); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx; + // Check if the output is successfully read. + ICHECK(!output.Empty()) << "Invalid output binding result."; + pipeline_config_.Insert(mod_idx, output); + } + return pipeline_config_; + } }; } // namespace runtime } // namespace tvm diff --git a/src/runtime/pipeline/pipeline_scheduler.cc b/src/runtime/pipeline/pipeline_scheduler.cc new file mode 100644 index 000000000000..82caf855a479 --- /dev/null +++ b/src/runtime/pipeline/pipeline_scheduler.cc @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "pipeline_scheduler.h" + +#include +#include +namespace tvm { +namespace runtime { +/*! + * \brief Initialize the pipeline. + * \param modules The list of graph executor modules. + * \param pipeline_conf The dependency information of each graph executor module. + */ +size_t PipelineScheduler::PipelineInit(const std::vector& modules, + const PipelineConfig& pipeline_config) { + graph_modules_ = modules; + int num_output = pipeline_config.GetGlobalOutputNum(); + return num_output; +} +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/pipeline/pipeline_scheduler.h b/src/runtime/pipeline/pipeline_scheduler.h new file mode 100644 index 000000000000..5ee127edffa3 --- /dev/null +++ b/src/runtime/pipeline/pipeline_scheduler.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_PIPELINE_PIPELINE_SCHEDULER_H_ +#define TVM_RUNTIME_PIPELINE_PIPELINE_SCHEDULER_H_ +#include +#include +#include + +#include +#include +#include +#include + +#include "pipeline_struct.h" +namespace tvm { +namespace runtime { +/*! + * \brief The class that executes the pipeline logic,it is used to initialize the thread pool, + execute and schedule pipeline tasks, allocate and manage memory, etc. + */ +class PipelineScheduler { + public: + /*! + * \brief Initialize the pipeline. + * \param modules The list of graph executor module. + * \param pipeline_config The dependency information of each graph executor module. + */ + size_t PipelineInit(const std::vector& modules, const PipelineConfig& pipeline_config); + + private: + /*!\brief The list of graph executors.*/ + std::vector graph_modules_; +}; +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_PIPELINE_PIPELINE_SCHEDULER_H_ diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h new file mode 100644 index 000000000000..3cc9621702c1 --- /dev/null +++ b/src/runtime/pipeline/pipeline_struct.h @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_PIPELINE_PIPELINE_STRUCT_H_ +#define TVM_RUNTIME_PIPELINE_PIPELINE_STRUCT_H_ +#include +#include +#include + +#include +#include +#include +#include +/*! + * \brief All binding information of a output interface. + */ +struct OutputBindings { + /*!\brief Output interface binding information, 'int' is the index of the module that + * uses this output data as the input interface data, 'string' is the input interface name + * of the module. + */ + std::unordered_map bindings; + /*! The index value of the global interface to which the current output are bound.*/ + int global_output_index = std::numeric_limits::min(); + /*!\brief Whether this binding is bound to the PipelineExecutor output interface.*/ + bool IsGlobalOutput() const { return global_output_index >= 0; } + /*! + * \brief Create a module interface map from JSONReader. + * \param reader JSON reader. + */ + void Load(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + std::string input_name; + int mod_idx = std::numeric_limits::min(); + // Whether the output binding is global. + bool global_binding = false; + while (reader->NextObjectItem(&key)) { + if (key == "mod_idx") { + reader->Read(&mod_idx); + } else if (key == "input_name") { + reader->Read(&input_name); + } else if (key == "global_output_index") { + // There should be only one global binding. + ICHECK(global_output_index < 0); + reader->Read(&global_output_index); + // When the key value is 'global_output_index', it means that this output is bound to + // a global interface. + global_binding = true; + } else { + LOG(FATAL) << "do not support key " << key; + } + } + // When this output is bound to a global interface, check if the global interface index + // start from 0. + if (global_binding) { + ICHECK(global_output_index >= 0); + } else { + // When this output is bound to a graph executor module interface, check if the module + // index start from 0. + ICHECK(mod_idx >= 0); + bindings[mod_idx] = input_name; + } + } + } +}; + +/*! + * \brief The binding information of all outputs of a module. + */ +struct OutputMap { + /*! \brief Output binding map, 'int' is output interface index.*/ + std::unordered_map output_binding_map; + OutputMap& operator=(const OutputMap& output) { + output_binding_map = output.output_binding_map; + return *this; + } + + /*!\brief This function is used to verify whether OutputMap is successfully loaded. + * \return Return true to indicate that this class has not been successfully loaded. + */ + bool Empty() { return output_binding_map.empty(); } + /*! \brief The pipeline outputs is the final outputs of pipeline, this function is used to + * get how many pipeline outputs are in this Outputmap + * \return Number of pipeline outputs. + */ + size_t GetGlobalOutputNum(void) const { + size_t num_output = 0; + for (auto bindings : output_binding_map) { + num_output += bindings.second.IsGlobalOutput() ? 1 : 0; + } + return num_output; + } + + /*! + * \brief Create a output binding map from JSONReader. + * \param reader Json reader. + */ + void Load(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + std::string key; + reader->BeginObject(); + int output_idx = -1; + OutputBindings binding; + while (reader->NextObjectItem(&key)) { + if (key == "output_idx") { + reader->Read(&output_idx); + } else if (key == "dependencies") { + reader->Read(&binding); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(output_idx >= 0); + output_binding_map[output_idx] = binding; + } + } +}; +/*! + * \brief The binding or dependency information of each module output interface. + */ +struct PipelineConfig { + /*!\brief The key is the module index, this variable records all module pipeline configuration + * information. + */ + std::unordered_map config; + OutputMap& operator[](int key) { + ICHECK(config.find(key) != config.end()); + return config[key]; + } + + void Insert(int key, const OutputMap& map) { config[key] = map; } + + /*!\brief This function is used to verify whether config is loaded successfully. + * \return Return true to indicate that this class has not been successfully loaded. + */ + bool Empty() { return config.empty(); } + + /*! + * \brief Get the number of global outputs. + * \return The number of outputs the entire pipeline has. + */ + size_t GetGlobalOutputNum() const { + size_t num_output = 0; + for (auto mod_output : config) { + num_output += mod_output.second.GetGlobalOutputNum(); + } + return num_output; + } +}; +/*! + * \brief The information used to initialize the graph executor module, the information + * come from the export library function call. + */ +struct GraphModuleLoadInfo { + GraphModuleLoadInfo(const std::string& lib, const std::string& json, const std::string& params, + const std::string& device) + : lib_name(lib), json_name(json), params_name(params), dev(device) {} + GraphModuleLoadInfo() { ; } + std::string lib_name; + std::string json_name; + std::string params_name; + std::string dev; +}; +/*! The Module information of each module.The 'int' is module index. */ +using ModuleConfig = std::unordered_map; +#endif // TVM_RUNTIME_PIPELINE_PIPELINE_STRUCT_H_ diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index d9411c92c375..4a9b7eacdf65 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -16,6 +16,7 @@ # under the License. import pytest +import os import numpy as np import tvm import tvm.testing @@ -76,11 +77,11 @@ def get_manual_conf(mods, target): # The third output is the final output, the second output is for mod3, the first output # is for mod2 input. pipe_config1 = { - "mod_idx": 1, + "mod_idx": 0, "output": [ - {"output_idx": 0, "dependent": [{"mod_idx": 2, "input_name": "data_0"}]}, - {"output_idx": 1, "dependent": [{"mod_idx": 3, "input_name": "data_0"}]}, - {"output_idx": 2, "dependent": [{"mod_idx": 0, "input_name": "0"}]}, + {"output_idx": 0, "dependencies": [{"mod_idx": 1, "input_name": "data_0"}]}, + {"output_idx": 1, "dependencies": [{"mod_idx": 2, "input_name": "data_0"}]}, + {"output_idx": 2, "dependencies": [{"global_output_index": 0}]}, ], } mod_config[mods[0]] = { @@ -94,9 +95,9 @@ def get_manual_conf(mods, target): } pipe_config2 = { - "mod_idx": 2, + "mod_idx": 1, "output": [ - {"output_idx": 0, "dependent": [{"mod_idx": 3, "input_name": "data_1"}]}, + {"output_idx": 0, "dependencies": [{"mod_idx": 2, "input_name": "data_1"}]}, ], } mod_config[mods[1]] = { @@ -110,8 +111,8 @@ def get_manual_conf(mods, target): } pipe_config3 = { - "mod_idx": 3, - "output": [{"output_idx": 0, "dependent": [{"mod_idx": 0, "input_name": "1"}]}], + "mod_idx": 2, + "output": [{"output_idx": 0, "dependencies": [{"global_output_index": 1}]}], } mod_config[mods[2]] = { "pipeline": pipe_config3, @@ -128,7 +129,7 @@ def get_manual_conf(mods, target): def test_pipe_config_check(): # This function is used to trigger runtime error by applying wrong logic connection. - # Get the three pipeline modules here. + # Get three pipeline modules here. (mod1, mod2, mod3), dshape = get_mannual_mod() # The input or output name is illegal and expects a runtime error. @@ -179,10 +180,12 @@ def test_pipeline(): pipe_config = pipeline_executor.PipelineConfig() - # The global input named "data_0" will be connected to a input named "data_0" of mod1. + # The pipeline input named "data_0" will be connected to a input named "data_0" + # of mod1. pipe_config["input"]["data_0"].connect(pipe_config[mod1]["input"]["data_0"]) - # The global Input named "data_1" will be connected to a input named "data_1" of mod2. + # The pipeline Input named "data_1" will be connected to a input named "data_1" + # of mod2. pipe_config["input"]["data_1"].connect(pipe_config[mod2]["input"]["data_1"]) # The mod1 output[0] will be connected to a input named "data_0" of mod2. @@ -194,10 +197,10 @@ def test_pipeline(): # The mod2 output[2] will be connected to a input named "data_1" of mod3. pipe_config[mod2]["output"][0].connect(pipe_config[mod3]["input"]["data_1"]) - # The mod1 output[2] will be connected to global output[1]. + # The mod1 output[2] will be connected to pipeline output[0]. pipe_config[mod1]["output"][2].connect(pipe_config["output"]["0"]) - # The mod3 output[0] will be connected to global output[2]. + # The mod3 output[0] will be connected to pipeline output[1]. pipe_config[mod3]["output"][0].connect(pipe_config["output"]["1"]) # Print configueration (print(pipe_config)), the result looks like following. # @@ -231,9 +234,21 @@ def test_pipeline(): with tvm.transform.PassContext(opt_level=3): pipeline_mod_factory = pipeline_executor.build(pipe_config) + # Export the parameter configuration to a file. + directory_path = tvm.contrib.utils.tempdir().temp_dir + # If the directory does not exist, create it. + if not os.path.exists(directory_path): + os.makedirs(directory_path) + config_file_name = pipeline_mod_factory.export_library(directory_path) + + # Use the output of build to create and initialize PipelineModule. pipeline_module = pipeline_executor.PipelineModule(pipeline_mod_factory) assert pipeline_module + # Use the import function to create and initialize PipelineModule. + pipeline_module_test = pipeline_executor.PipelineModule.load_library(config_file_name) + assert pipeline_module_test.num_outputs == 2 + if __name__ == "__main__": pytest.main([__file__]) From 4c0026495551d2797bf5df016386c712ba6f4a05 Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Thu, 14 Oct 2021 05:16:39 +0100 Subject: [PATCH 27/84] [CI] Pre-build Reference System Dependencies (#9270) Building these dependencies from scratch in each test was taking much longer than really necessary. Before: ``` $ time python3 -m pytest tests/python/contrib/test_ethosu/test_codegen.py::test_tflite_depthwise_conv2d[strides0-dilation0-SAME-kernel_shape0-relu-ifm_shape0-ethos-u55-256] real 0m19.982s user 0m13.255s sys 0m3.403s ``` After: ``` $ time python3 -m pytest tests/python/contrib/test_ethosu/test_codegen.py::test_tflite_depthwise_conv2d[strides0-dilation0-SAME-kernel_shape0-relu-ifm_shape0-ethos-u55-256] real 0m10.963s user 0m5.516s sys 0m2.232s ``` --- docker/install/ubuntu_install_ethosu_driver_stack.sh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docker/install/ubuntu_install_ethosu_driver_stack.sh b/docker/install/ubuntu_install_ethosu_driver_stack.sh index c5ab20e18277..db8b47399390 100755 --- a/docker/install/ubuntu_install_ethosu_driver_stack.sh +++ b/docker/install/ubuntu_install_ethosu_driver_stack.sh @@ -92,3 +92,13 @@ cd "${ethosu_dir}" git clone "https://github.com/ARM-software/CMSIS_5.git" cmsis cd cmsis git checkout -f tags/${cmsis_ver} + +# Build Driver +mkdir ${ethosu_dir}/core_driver/build && cd ${ethosu_dir}/core_driver/build +cmake -DCMAKE_TOOLCHAIN_FILE=${ethosu_dir}/core_platform/cmake/toolchain/arm-none-eabi-gcc.cmake -DETHOSU_LOG_SEVERITY=debug -DTARGET_CPU=cortex-m55 .. +make + +# Build NN Library +mkdir ${ethosu_dir}/cmsis/CMSIS/NN/build/ && cd ${ethosu_dir}/cmsis/CMSIS/NN/build/ +cmake .. -DCMAKE_TOOLCHAIN_FILE=${ethosu_dir}/core_platform/cmake/toolchain/arm-none-eabi-gcc.cmake -DTARGET_CPU=cortex-m55 -DBUILD_CMSIS_NN_FUNCTIONS=YES +make From 523eb12a1a9ce92777afbe90e0b87f47712cad6a Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Wed, 13 Oct 2021 23:26:18 -0500 Subject: [PATCH 28/84] [Pytest] Sort unit tests before running. (#9188) * [Pytest] Sort unit tests before running. By default, pytest will sort tests to maximize the re-use of fixtures. However, this assumes that all fixtures have an equal cost to generate, and no caches outside of those managed by pytest. A fixture for a `tvm.testing.parameter` is effectively free, while a fixture maintaining a cache of reference data `tvm.testing.utils._fixture_cache` be quite large. Since most of the TVM fixtures are specific to a python function, sort the test ordering by python function, so that tvm.testing.utils._fixture_cache can be cleared sooner rather than later. * Updated TestTargetAutoParametrization When sorting the tests, the order of parametrizations may change. Therefore, the tests checking for automatic target parametrization shouldn't depend on order. --- python/tvm/testing/plugin.py | 20 +++++++++++++++++++ .../unittest/test_tvm_testing_features.py | 15 ++++++++------ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index 0413c44208b0..2cb228c357e5 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -74,6 +74,7 @@ def pytest_collection_modifyitems(config, items): # pylint: disable=unused-argument _count_num_fixture_uses(items) _remove_global_fixture_definitions(items) + _sort_tests(items) @pytest.fixture @@ -236,6 +237,25 @@ def _remove_global_fixture_definitions(items): delattr(module, name) +def _sort_tests(items): + """Sort tests by file/function. + + By default, pytest will sort tests to maximize the re-use of + fixtures. However, this assumes that all fixtures have an equal + cost to generate, and no caches outside of those managed by + pytest. A tvm.testing.parameter is effectively free, while + reference data for testing may be quite large. Since most of the + TVM fixtures are specific to a python function, sort the test + ordering by python function, so that + tvm.testing.utils._fixture_cache can be cleared sooner rather than + later. + + Should be called from pytest_collection_modifyitems. + + """ + items.sort(key=lambda item: item.location) + + def _target_to_requirement(target): if isinstance(target, str): target = tvm.target.Target(target) diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index cbcdc4356250..c00fc02c4331 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -46,8 +46,11 @@ def test_device_parametrization(self, dev): self.devices_used.append(dev) def test_all_targets_used(self): - assert self.targets_used == self.enabled_targets - assert self.devices_used == self.enabled_devices + assert sorted(self.targets_used) == sorted(self.enabled_targets) + + def test_all_devices_used(self): + sort_key = lambda dev: (dev.device_type, dev.device_id) + assert sorted(self.devices_used, key=sort_key) == sorted(self.enabled_devices, key=sort_key) targets_with_explicit_list = [] @@ -70,9 +73,9 @@ def test_exclude_target(self, target): self.targets_with_exclusion.append(target) def test_all_nonexcluded_targets_ran(self): - assert self.targets_with_exclusion == [ - target for target in self.enabled_targets if not target.startswith("llvm") - ] + assert sorted(self.targets_with_exclusion) == sorted( + [target for target in self.enabled_targets if not target.startswith("llvm")] + ) run_targets_with_known_failure = [] @@ -85,7 +88,7 @@ def test_known_failing_target(self, target): assert "llvm" not in target def test_all_targets_ran(self): - assert self.run_targets_with_known_failure == self.enabled_targets + assert sorted(self.run_targets_with_known_failure) == sorted(self.enabled_targets) @tvm.testing.known_failing_targets("llvm") @tvm.testing.parametrize_targets("llvm") From 7e014a441b35001318a3687cfb99943f5a4065a7 Mon Sep 17 00:00:00 2001 From: CircleSpin <2keepconnected@gmail.com> Date: Thu, 14 Oct 2021 01:46:53 -0400 Subject: [PATCH 29/84] [ONNX] [Relay] Resize Opset 13 (#9265) * Fix handling of optional inputs. * Missed one test in the ignore list. * split 11 and 13 * removed comments, adjusted for git review Co-authored-by: Josh Fromm Co-authored-by: Matthew Co-authored-by: CircleSpin --- python/tvm/relay/frontend/onnx.py | 44 +++++++++++++++++----- tests/python/frontend/onnx/test_forward.py | 9 +---- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3479e1e7c36e..5c112c7dfce0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2697,6 +2697,40 @@ def _impl_v10(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): + scale = inputs[2] + scale_shape = infer_shape(scale) + if len(inputs) == 4: + assert ( + len(scale_shape) == 0 or scale_shape[0] == 0 + ), "One of scale or size should be passed, not both." + size = inputs[3] + else: + assert len(scale_shape) != 0, "One of scale or size should be passed." + size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale + return cls.v11_13_common(inputs, size, attr, params) + + @classmethod + def _impl_v13(cls, inputs, attr, params): + scale = inputs[2] + size = inputs[3] + if size is not None: + assert scale is None, "One of scale or size should be passed, not both." + else: + scale_type = infer_type(scale) + scale_shape = scale_type.checked_type.shape + scale_dtype = scale_type.checked_type.dtype + assert len(scale_shape) != 0, "One of scale or size should be passed." + size = _op.cast(shape_of(inputs[0]), scale_dtype) * scale + + return cls.v11_13_common(inputs, size, attr, params) + + @classmethod + def v11_13_common(cls, inputs, size, attr, params): + """ + Resize v11 and Resize v13 are identical except in how + they handle the passing of scale and size. This utility + provides the implementation for both + """ ndims = len(infer_shape(inputs[0])) mode = attr.get("mode").decode("ascii") if mode == "nearest": @@ -2715,16 +2749,6 @@ def _impl_v11(cls, inputs, attr, params): alpha = attr.get("cubic_coeff_a", -0.75) exclude = attr.get("exclude_outside", 0) - scale = inputs[2] - scale_shape = infer_shape(scale) - if len(inputs) == 4: - assert ( - len(scale_shape) == 0 or scale_shape[0] == 0 - ), "One of scale or size should be passed, not both." - size = inputs[3] - else: - assert len(scale_shape) != 0, "One of scale or size should be passed." - size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) out = None if ndims == 3: diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 2301747034dd..dd1c77330986 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3970,6 +3970,7 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), ] input_names = ["X", "roi", "scales"] + if oshape != []: nodes.append( make_constant_node("sizes", onnx.TensorProto.INT64, (len(oshape),), oshape) @@ -4954,15 +4955,7 @@ def verify_eyelike(indata): "test_reduce_sum_keepdims_random", "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", - "test_resize_downsample_sizes_cubic", - "test_resize_downsample_sizes_linear_pytorch_half_pixel", - "test_resize_downsample_sizes_nearest", "test_resize_tf_crop_and_resize", - "test_resize_upsample_sizes_cubic", - "test_resize_upsample_sizes_nearest", - "test_resize_upsample_sizes_nearest_ceil_half_pixel", - "test_resize_upsample_sizes_nearest_floor_align_corners", - "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", "test_rnn_seq_length", "test_round", "test_scan9_sum", From 575ac8678b8dd6600ed55246e2ec16b2f3a16aca Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Thu, 14 Oct 2021 06:48:28 +0100 Subject: [PATCH 30/84] Skip onnx test cases if no onnx (#9272) This was missing a guard which meant VS Code errored on test collection. --- .../test_quantization_accuracy_for_vit.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py b/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py index 8cecbf97c001..484ec23b369a 100644 --- a/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py +++ b/tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py @@ -14,15 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm import os import sys -from tvm import relay -from tvm.relay import quantize as qtz import logging + +import pytest + +pytest.importorskip("onnx") + import onnx + +import tvm +from tvm import relay +from tvm.relay import quantize as qtz import tvm.testing -import mxnet as mx from test_quantization_accuracy import Config, get_val_data, eval_acc logging.basicConfig(level=logging.INFO) From d153676afdde2bf9a4fdf137d235e4f26d362bc6 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Thu, 14 Oct 2021 01:02:19 -0700 Subject: [PATCH 31/84] Update TVM_LOG_DEBUG for IR tracing. (#9278) * Update TVM_LOG_DEBUG for IR tracing. Forgot to do this when I switched to VLOG, sorry. * Woops, remove src/ prefix. --- docs/install/from_source.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index b28c18162437..23be3198bf7c 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -107,7 +107,7 @@ The configuration of TVM can be modified by editing `config.cmake` and/or by pas .. code:: bash - export TVM_LOG_DEBUG=1 + export TVM_LOG_DEBUG="ir/transform.cc=1;relay/ir/transform.cc=1" - TVM requires LLVM for for CPU codegen. We highly recommend you to build with the LLVM support on. From 594f23d976f09d4aa300f84de0c9a7906b71eeee Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Thu, 14 Oct 2021 11:09:02 +0300 Subject: [PATCH 32/84] [Core][Build] Move build module transformations and utilities to C++ (#9103) * Initial investigation * More progress! * More progress / notes * rewrite build_for_device mostly in c++ * More progress * Initial split of transformations applied to device and host as post split action from mixed module * Combine duplicate passes after spliting mod on aot and vm flows * Minor cleanup * Move target mangling to driver_api.cc * Move more build utlities to cpp driver api * [Build][WIP] Moving build utilities to C++ from Python * [Build] Remove comments * [lint] Pass black * More formating * Move more build functionality into cpp * Remove comments * Remove unused defs and imports * Address PR comments * More PR comments * More comments * More comments * Add comments on the new split function * Fix PR comments on clarity * Test CI * Fix format * Refactor build * Expose splitted composite passes to python * Format files * Test fix * Fix for annotating entry funcs on code targeting CPU * Prevent entry funcs to be annotated when compiling for CPU with C runtime enabled * Guard for aot executor entry * Sphix format * Sanity fix * Sphinx fix Co-authored-by: electriclilies --- include/tvm/driver/driver_api.h | 30 ++++ python/tvm/driver/build_module.py | 125 ++----------- python/tvm/relay/build_module.py | 6 +- src/driver/driver_api.cc | 280 +++++++++++++++++++++--------- src/relay/backend/vm/compiler.cc | 2 +- 5 files changed, 238 insertions(+), 205 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 418d532fdd5f..45a938247cc8 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -30,6 +30,7 @@ #define TVM_DRIVER_DRIVER_API_H_ #include +#include #include #include #include @@ -43,6 +44,34 @@ #include namespace tvm { +using tvm::transform::Pass; + +/*! + * \brief Configures and returns the composite Pass for the fused module (pre split) that contains + * device and host code. + * \param mixed_mod The original mixed module. + * \param target The device Target. + * \return The composite Pass for the fused module. +// */ +TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); + +/*! + * \brief Configures and returns the composite Pass for the device Target after device/host from + * mixed module. + * \param mixed_mod The optimized mixed module. + * \param target The device Target. + * \return The composite Pass for the device module. + */ +TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target); + +/*! + * \brief Configures and returns the composite Pass for the host Target after device/host from mixed + * module. + * \param mixed_mod The optimized mixed module. + * \param target_host The host Target. + * \return The composite Pass for the host module. + */ +TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host); /*! * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) @@ -136,6 +165,7 @@ TVM_DLL runtime::Module build(const Map& input, const Target& * \return The built module that contains code for different processors. */ TVM_DLL runtime::Module build(const Map& input, const Target& target_host); + } // namespace tvm #endif // TVM_DRIVER_DRIVER_API_H_ diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a7ebc00c315f..429b3e1727cc 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -16,27 +16,23 @@ # under the License. # pylint: disable=invalid-name -"""The build utils in python. -""" +"""The build utils in python.""" from typing import Union, Optional, List, Mapping -import warnings import tvm.tir from tvm.runtime import Module from tvm.runtime import ndarray from tvm.ir import container -from tvm.ir import CallingConv from tvm.tir import PrimFunc from tvm.ir.module import IRModule -from tvm.ir.transform import PassContext -from tvm.target import codegen from tvm.te import tensor from tvm.te import schedule from tvm.target import Target from tvm.tir.buffer import Buffer from tvm.tir.expr import Var +from tvm.driver import _ffi_api as _driver_ffi from . import _ffi_api as ffi @@ -104,8 +100,8 @@ def lower( args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] The argument lists to the function for TE schedule. - It should be None if we want to lower TensorIR. + It should be None if we want to lower TensorIR. name : str The name of the result function. @@ -132,98 +128,6 @@ def lower( raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) -def _build_for_device(input_mod, target, target_host): - """Build the lowered functions for a device with the given compilation - target. - - Parameters - ---------- - input_mod : IRModule - The schedule to be built. - - target : str or :any:`tvm.target.Target` - The target and option of the compilation. - - target_host : str or :any:`tvm.target.Target` - The host compilation target. - - Returns - ------- - fhost : IRModule - The host IRModule. - - mdev : tvm.module - A module that contains device code. - """ - target, target_host = Target.check_and_update_host_consist(target, target_host) - device_type = ndarray.device(target.kind.name, 0).device_type - - mod_mixed = input_mod - mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - - opt_mixed = [ - tvm.tir.transform.VerifyMemory(), - tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), - ] - if len(mod_mixed.functions) == 1: - opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] - - if PassContext.current().config.get("tir.detect_global_barrier", False): - opt_mixed += [tvm.tir.transform.ThreadSync("global")] - opt_mixed += [ - tvm.tir.transform.ThreadSync("shared"), - tvm.tir.transform.ThreadSync("warp"), - tvm.tir.transform.InferFragment(), - tvm.tir.transform.LowerThreadAllreduce(), - tvm.tir.transform.MakePackedAPI(), - tvm.tir.transform.SplitHostDevice(), - ] - mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed) - - # device optimizations - opt_device = tvm.transform.Sequential( - [ - tvm.tir.transform.Filter( - lambda f: "calling_conv" in f.attrs - and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH - ), - tvm.tir.transform.LowerWarpMemory(), - tvm.tir.transform.Simplify(), - tvm.tir.transform.LowerDeviceStorageAccessInfo(), - tvm.tir.transform.LowerCustomDatatypes(), - tvm.tir.transform.LowerIntrin(), - ] - ) - mod_dev = opt_device(mod_mixed) - - # host optimizations - opt_host = tvm.transform.Sequential( - [ - tvm.tir.transform.Filter( - lambda f: "calling_conv" not in f.attrs - or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH - ), - tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)), - tvm.tir.transform.LowerTVMBuiltin(), - tvm.tir.transform.LowerDeviceStorageAccessInfo(), - tvm.tir.transform.LowerCustomDatatypes(), - tvm.tir.transform.LowerIntrin(), - tvm.tir.transform.CombineContextCall(), - ] - ) - mod_host = opt_host(mod_mixed) - - if device_type == ndarray.cpu(0).device_type and target_host == target: - assert len(mod_dev.functions) == 0 - if "gpu" in target.keys and len(mod_dev.functions) == 0: - warnings.warn( - "Specified target %s, but cannot find device code, did you do " "bind?" % target - ) - - rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None - return mod_host, rt_mod_dev - - def build( inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -237,7 +141,8 @@ def build( Parameters ---------- - inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] + inputs : Union[tvm.te.schedule.Schedule, + tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] The input to be built args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] @@ -253,7 +158,7 @@ def build( setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. name : Optional[str] The name of result function. @@ -350,21 +255,11 @@ def build( target_input_mod, target_host ) - mod_host_all = tvm.IRModule({}) - - device_modules = [] - for tar, input_mod in target_input_mod.items(): - mod_host, mdev = _build_for_device(input_mod, tar, target_host) - mod_host_all.update(mod_host) - device_modules.append(mdev) - - # Generate a unified host module. - rt_mod_host = codegen.build_module(mod_host_all, target_host) + rt_mod_host = _driver_ffi.finalize_module(target_input_mod, target_host) - # Import all modules. - for mdev in device_modules: - if mdev: - rt_mod_host.import_module(mdev) + target_input_mod, target_host = Target.check_and_update_host_consist( + target_input_mod, target_host + ) if not isinstance(target_host, Target): target_host = Target(target_host) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index c67ac1dc423d..f1686d2a03bb 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -123,7 +123,7 @@ def build( to setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. params : dict of str to NDArray Input parameters to the graph that do not change @@ -303,7 +303,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. params : dict of str to NDArray Input parameters to the graph that do not change @@ -452,7 +452,7 @@ def bind_params_by_name(func, params): class GraphExecutor(_interpreter.Executor): """Wrapper around Executor interface. - This executor is used for debug and testing purpoes. + This executor is used for debug and testing purposes. Parameters ---------- diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 2c6fbc2eb76d..e659421c23c4 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -42,17 +42,26 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +using tvm::Array; +using tvm::transform::Pass; bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); return pf != nullptr; } +bool ShouldAnnotateEntryFunc(const Target target, const IRModule mod) { + const bool aot_executor = (target->GetAttr("executor").value_or("") == "aot"); + const bool single_entry_func = (mod->functions.size() == 1); + return single_entry_func && !aot_executor; +} + /*! \return The default host target for a given device target */ Target DefaultTargetHost(Target target) { if (target.defined() && target->kind->device_type == kDLCPU) { @@ -155,6 +164,13 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); } +static transform::Pass AnnotateEntryFunc(bool b) { + auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {}); +} + template transform::Pass Filter(FCond fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { @@ -184,7 +200,7 @@ Array CreatePassList(bool disable_loop_partition) { Array user_lower_phase2 = Array(); Array user_lower_phase3 = Array(); - // phase pasees is of the form + // phase passes is of the form // [[phase_number, pass], [phase_number, pass]... ] for (Array phase_pass : add_lower_pass) { const IntImmNode* phase_num = phase_pass[0].as(); @@ -266,6 +282,11 @@ IRModule LowerWithPassList(IRModule mod, Array pass_list) return mod; } +IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { + mod = seq(std::move(mod)); + return mod; +} + IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { // Convert te schedule to IRModule @@ -373,97 +394,96 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); }); -std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg, - const transform::PassContext& pass_ctx) { +/** + * This function takes the input module that contains both the device and host opts. + * Then, it applies transformation on the original module before splitting into separate modules for + * device and host. Then it also applies transformations on the new splitted modules. + */ +std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, + const Target& target_host_arg) { Target target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); - Array mixed_pass_list = {BindTarget(target), - tir::transform::VerifyMemory()}; - mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - if (pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value()) { - mixed_pass_list.push_back(tir::transform::ThreadSync("global")); - } - mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); - mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); - mixed_pass_list.push_back(tir::transform::InferFragment()); - mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); + ICHECK(mod_mixed.defined()) << "This module must be defined"; - if (target->GetAttr("unpacked-api").value_or(Bool(false))) { - mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); - } else { - mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); - } + mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); - mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); - auto opt_mixed = transform::Sequential(mixed_pass_list); - mod_mixed = opt_mixed(std::move(mod_mixed)); - - // We make an assumption here that the overriden host target - // can be used alongside the default host codegen based on device type - // this is so the correct code generator is used later instead of overriding the target. - // We need better support for inserting multiple kDLCPU targets as our current options - // are kDeviceKernelLaunch or not - Target overriden_host_target = target_host; - if (target->kind->device_type == target_host->kind->device_type) { - overriden_host_target = target; - } - auto host_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != - CallingConv::kDeviceKernelLaunch; - }), - BindTarget(overriden_host_target), - tir::transform::LowerTVMBuiltin(), - tir::transform::LowerCustomDatatypes(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - tir::transform::CombineContextCall(), - }; - auto opt_host = transform::Sequential(host_pass_list); - ICHECK(mod_mixed.defined()) << "This module must be defined"; - auto mhost = opt_host(mod_mixed); - - // device pipeline - auto device_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == - CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target), - tir::transform::LowerWarpMemory(), - tir::transform::Simplify(), - tir::transform::LowerCustomDatatypes(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - }; - auto opt_device = transform::Sequential(device_pass_list); - auto mdevice = opt_device(mod_mixed); + IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); - // some final misc checks. auto keys = target->GetKeys(); + + CheckAndUpdateHostConsistency(&target, &target_host); + bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && mdevice->functions.size() == 0) { - LOG(WARNING) << "Specified target " << target->str() - << " but cannot find device code. Did you forget to bind?"; + if (target_is_gpu && device_mod->functions.size() == 0) { + DLOG(WARNING) << "Specified target " << target->str() + << " but cannot find device code. Did you forget to bind?"; + } + + return {host_mod, device_mod}; +} + +runtime::Module FinalizeModule(const Map& inputs_arg, const Target& host_target) { + std::vector device_modules; + Map inputs = inputs_arg; + Target target_host = host_target; + + CheckAndUpdateHostConsistency(&inputs, &target_host); + + if (!target_host.defined()) { + for (const auto& it : inputs) { + if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) { + target_host = it.first; + break; + } + } + } + + if (!target_host.defined()) { + target_host = DefaultTargetHost(target_host); } - if (target->kind->device_type == kDLCPU && target_host == target) { - // TODO(@jroesch): This check is no longer true we need to figure out if we care about this. - // We need to relax this check for just TIR functions. - // ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " - // << "and host_target are both llvm target." - // << "\n"; + // Update target host for all targets + CheckAndUpdateHostConsistency(&inputs, &target_host); + + IRModule mhost_all = IRModule(Map()); + + ICHECK(mhost_all.defined()) << "The host module must be defined"; + + for (const auto& it : inputs) { + if (it.second.defined()) { + auto pair = SplitMixedModule(it.second, it.first, target_host); + auto& host_mod = pair.first; + auto& device_mod = pair.second; + + ICHECK(host_mod.defined()) << "The split host module must be defined"; + + ICHECK(mhost_all.defined()) << "The host module must be defined"; + + mhost_all->Update(host_mod); + + if (device_mod->functions.size() != 0) { + device_modules.push_back(codegen::Build(device_mod, it.first)); + } + } } - return {mhost, mdevice}; + runtime::Module complete_mod = codegen::Build(mhost_all, target_host); + for (const auto& it : device_modules) { + if (it.operator->()) { + complete_mod.Import(it); + } + } + return complete_mod; } -// Can we make this take one annotated IRModule? -// -// Build for heterogeneous execution. +TVM_REGISTER_GLOBAL("driver.finalize_module") + .set_body_typed([](const Map& inputs_arg, Target host_target) { + return FinalizeModule(inputs_arg, host_target); + }); + runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); @@ -498,11 +518,11 @@ runtime::Module build(const Map& inputs_arg, const Target& tar if (it.second.defined()) { const Target& target = it.first; const IRModule& ir_module = it.second; - auto pair = SplitDevHostFuncs(ir_module, target, target_host, pass_ctx); - auto& mhost = pair.first; - auto& mdevice = pair.second; + auto pair = SplitMixedModule(ir_module, target, target_host); + auto& host_mod = pair.first; + auto& device_mod = pair.second; - ICHECK(mhost.defined()) << "The split host module must be defined"; + ICHECK(host_mod.defined()) << "The split host module must be defined"; ICHECK(mhost_all.defined()) << "The host module must be defined"; @@ -513,19 +533,18 @@ runtime::Module build(const Map& inputs_arg, const Target& tar bool overrides_host_target = target->kind->device_type == target_host->kind->device_type; bool non_host_target_kind = target->kind != target_host->kind; if (overrides_host_target && non_host_target_kind) { - device_modules.push_back(codegen::Build(mhost, it.first)); + device_modules.push_back(codegen::Build(host_mod, it.first)); } else { - mhost_all->Update(mhost); + mhost_all->Update(host_mod); } - if (mdevice->functions.size() != 0) { - device_modules.push_back(codegen::Build(mdevice, it.first)); + if (device_mod->functions.size() != 0) { + device_modules.push_back(codegen::Build(device_mod, it.first)); } } } runtime::Module mhost = codegen::Build(mhost_all, target_host); - // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { mhost.Import(it); @@ -556,8 +575,97 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, const Target& target_host_arg) { auto target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); + // More maps of target and target host Map inputs = {{target, funcs}}; return build(inputs, target_host); } +transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + + Array mixed_pass_list; + + mixed_pass_list.push_back(BindTarget(target)); + + mixed_pass_list.push_back(tir::transform::VerifyMemory()); + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); + + if (ShouldAnnotateEntryFunc(target, mixed_mod)) { + mixed_pass_list.push_back(AnnotateEntryFunc(true)); + } + + bool detect_global_barrier = + pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); + if (detect_global_barrier) { + mixed_pass_list.push_back(tir::transform::ThreadSync("global")); + } + + mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); + mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); + mixed_pass_list.push_back(tir::transform::InferFragment()); + mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); + + if (target->GetAttr("unpacked-api").value_or(Bool(false))) { + mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); + } else { + mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); + } + mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + + return transform::Sequential(mixed_pass_list); +} + +TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target) { + return MixedModulePassManager(mixed_mod, target); + }); + +transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { + Array host_pass_list; + host_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + CallingConv::kDeviceKernelLaunch; + })); + + ICHECK(mixed_mod.defined()) << "This module must be defined"; + + host_pass_list.push_back(BindTarget(target_host)); + + host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); + host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + host_pass_list.push_back(tir::transform::LowerIntrin()); + host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + host_pass_list.push_back(tir::transform::CombineContextCall()); + + return transform::Sequential(host_pass_list); +} + +TVM_REGISTER_GLOBAL("driver.host_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target_host) { + return HostModulePassManager(mixed_mod, target_host); + }); + +transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { + Array device_pass_list; + device_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDeviceKernelLaunch; + })); + + device_pass_list.push_back(BindTarget(target)); + + device_pass_list.push_back(tir::transform::LowerWarpMemory()); + device_pass_list.push_back(tir::transform::Simplify()); + device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + device_pass_list.push_back(tir::transform::LowerIntrin()); + + return transform::Sequential(device_pass_list); +} + +TVM_REGISTER_GLOBAL("driver.device_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target_host) { + return DeviceModulePassManager(mixed_mod, target_host); + }); + } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 36cd0c7f406d..70ad2ccc992e 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -824,7 +824,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { /*! * \brief Compile a pattern match expression - * It first converts the pattern match expression into a desicision tree, the condition + * It first converts the pattern match expression into a decision tree, the condition * could be object comparison or variable binding. If any of the condition fails in a clause, * the decision tree switches to check the conditions of next clause and so on. If no clause * matches the value, a fatal node is inserted. From 217763279371bc8eec4b61152e8b2deaef4b8a8c Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Thu, 14 Oct 2021 06:52:41 -0700 Subject: [PATCH 33/84] [Tutorial] Fix vta vision detection tutorial 'sphinx' style error. (#9279) Issue: Some bash code in this tutorial does not get syntax highlighting because of the format errors. Solution: Fix the 'sphinx' 'rst' style error. --- vta/tutorials/frontend/deploy_detection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vta/tutorials/frontend/deploy_detection.py b/vta/tutorials/frontend/deploy_detection.py index 771801851a48..cbd22a752049 100644 --- a/vta/tutorials/frontend/deploy_detection.py +++ b/vta/tutorials/frontend/deploy_detection.py @@ -34,15 +34,15 @@ # # .. code-block:: bash # -# pip3 install "Pillow<7" +# pip3 install "Pillow<7" # # YOLO-V3-tiny Model with Darknet parsing have dependancy with CFFI and CV2 library, # we need to install CFFI and CV2 before executing this script. # -# pip3 install "Pillow<7" +# .. code-block:: bash # -# pip3 install cffi -# pip3 install opencv-python +# pip3 install cffi +# pip3 install opencv-python # # Now return to the python code. Import packages. From 59b3cf7a8a49f78dd60bda8982aace8ee6dc3d7e Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Thu, 14 Oct 2021 18:03:49 +0100 Subject: [PATCH 34/84] Reset sphinx-gallery version to 0.4.0 (#9280) Seeing: ``` ERROR: Could not find a version that satisfies the requirement sphinx-gallery==0.4.1 (from versions: 0.0.4, 0.0.5, 0.0.6, 0.0.7, 0.0.8, 0.0.10, 0.0.11.post1, 0.1.0, 0.1.1, 0.1.2, 0.1.3, 0.1.4, 0.1.5, 0.1.6, 0.1.7, 0.1.8, 0.1.9, 0.1.10, 0.1.11, 0.1.12, 0.1.13, 0.2.0, 0.3.0, 0.3.1, 0.4.0, 0.5.0, 0.6.0, 0.6.1, 0.6.2, 0.7.0, 0.8.0, 0.8.1, 0.8.2, 0.9.0, 0.10.0) ERROR: No matching distribution found for sphinx-gallery==0.4.1 ``` This was changed in https://github.com/apache/tvm/pull/9115 --- docker/install/ubuntu_install_sphinx.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index 66720d411832..12208bbe6643 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -29,5 +29,5 @@ pip3 install \ matplotlib \ sphinx \ sphinx_autodoc_annotation \ - sphinx-gallery==0.4.1 \ + sphinx-gallery==0.4.0 \ sphinx_rtd_theme From 95a20315d0d3403e5eea3bfe6ed5e942a51b169d Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Thu, 14 Oct 2021 19:18:59 +0100 Subject: [PATCH 35/84] [Tests] Ensure MyPy type checks pass (#9284) * [Tests] Ensure MyPy type checks pass There's a few errors that come up when type checking that aren't triggering any failures: ``` Checking MyPy Type defs in the meta schedule package. python/tvm/meta_schedule/utils.py:23:1: error: Cannot find implementation or library stub for module named "psutil" python/tvm/meta_schedule/utils.py:23:1: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports python/tvm/meta_schedule/search_strategy/search_strategy.py: note: In member "__init__" of class "MeasureCandidate": python/tvm/meta_schedule/search_strategy/search_strategy.py:59:13: error: Module has no attribute "MeasureCandidate" python/tvm/meta_schedule/search_strategy/search_strategy.py: note: In member "initialize_with_tune_context" of class "SearchStrategy": python/tvm/meta_schedule/search_strategy/search_strategy.py:83:9: error: Module has no attribute "SearchStrategyInitializeWithTuneContext" ``` To rectify this the `types-psutil` package adds type hints for `mypy` and `# type: ignore` stops `mypy` from trying to figure out types of `_ffi_api` resources. There's also a few places where variable type definitions are repeated even though they're only required once. Finally, I've ensured `task_mypy.sh` fails the build since it's stable right now, using `set -e`. * Add temporary # type : ignore for psutil --- python/gen_requirements.py | 1 + .../tvm/meta_schedule/runner/local_runner.py | 10 ++++------ .../search_strategy/replay_trace.py | 2 +- .../search_strategy/search_strategy.py | 14 ++++++------- .../task_scheduler/task_scheduler.py | 20 +++++++++---------- python/tvm/meta_schedule/utils.py | 2 +- tests/scripts/task_mypy.sh | 3 +++ 7 files changed, 27 insertions(+), 25 deletions(-) diff --git a/python/gen_requirements.py b/python/gen_requirements.py index fa94d6a64130..bdaa58e1449a 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -198,6 +198,7 @@ "sphinx_autodoc_annotation", "sphinx_gallery", "sphinx_rtd_theme", + "types-psutil", ], ), ), diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index e3c78d741b20..caa266f97eb3 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -239,13 +239,11 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: result: List[float] = future.result() error_message: str = None except TimeoutError as exception: - result: List[float] = None - error_message: str = ( - f"LocalRunner: Timeout, killed after {self.timeout_sec} seconds\n" - ) + result = None + error_message = f"LocalRunner: Timeout, killed after {self.timeout_sec} seconds\n" except Exception as exception: # pylint: disable=broad-except - result: List[float] = None - error_message: str = "LocalRunner: An exception occurred\n" + str(exception) + result = None + error_message = "LocalRunner: An exception occurred\n" + str(exception) local_future = LocalRunnerFuture(res=result, error_message=error_message) results.append(local_future) return results diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 3afdff6de77e..15f8295f2524 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -41,7 +41,7 @@ class ReplayTrace(SearchStrategy): def __init__(self, num_trials_per_iter: int, num_trials_total: int): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.ReplayTrace, # pylint: disable=no-member + _ffi_api.ReplayTrace, # type: ignore # pylint: disable=no-member num_trials_per_iter, num_trials_total, ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 72713155c41d..d270ea61f6dc 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -56,7 +56,7 @@ def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: The argument information. """ self.__init_handle_by_constructor__( - _ffi_api.MeasureCandidate, # pylint: disable=no-member + _ffi_api.MeasureCandidate, # type: ignore # pylint: disable=no-member sch, args_info, ) @@ -80,7 +80,7 @@ def initialize_with_tune_context( tune_context : TuneContext The tuning context for initialization. """ - _ffi_api.SearchStrategyInitializeWithTuneContext( # pylint: disable=no-member + _ffi_api.SearchStrategyInitializeWithTuneContext( # type: ignore # pylint: disable=no-member self, tune_context ) @@ -92,11 +92,11 @@ def pre_tuning(self, design_spaces: List[Schedule]) -> None: design_spaces : List[Schedule] The design spaces for pre-tuning. """ - _ffi_api.SearchStrategyPreTuning(self, design_spaces) # pylint: disable=no-member + _ffi_api.SearchStrategyPreTuning(self, design_spaces) # type: ignore # pylint: disable=no-member def post_tuning(self) -> None: """Post-tuning for the search strategy.""" - _ffi_api.SearchStrategyPostTuning(self) # pylint: disable=no-member + _ffi_api.SearchStrategyPostTuning(self) # type: ignore # pylint: disable=no-member def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: """Generate measure candidates from design spaces for measurement. @@ -106,7 +106,7 @@ def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: measure_candidates : Optional[List[IRModule]] The measure candidates generated, None if finished. """ - return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # pylint: disable=no-member + return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member def notify_runner_results(self, results: List[RunnerResult]) -> None: """Update the search strategy with profiling results. @@ -116,7 +116,7 @@ def notify_runner_results(self, results: List[RunnerResult]) -> None: results : List[RunnerResult] The profiling results from the runner. """ - _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # pylint: disable=no-member + _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # type: ignore # pylint: disable=no-member @register_object("meta_schedule.PySearchStrategy") @@ -142,7 +142,7 @@ def f_notify_runner_results(results: List["RunnerResult"]) -> None: self.notify_runner_results(results) self.__init_handle_by_constructor__( - _ffi_api.SearchStrategyPySearchStrategy, # pylint: disable=no-member + _ffi_api.SearchStrategyPySearchStrategy, # type: ignore # pylint: disable=no-member f_initialize_with_tune_context, f_pre_tuning, f_post_tuning, diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index b8dcfd9e7a2d..f1e21ad3ddfe 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -27,7 +27,7 @@ class TaskScheduler(Object): def tune(self) -> None: """Auto-tuning.""" - _ffi_api.TaskSchedulerTune(self) # pylint: disable=no-member + _ffi_api.TaskSchedulerTune(self) # type: ignore # pylint: disable=no-member def _set_task_stopped(self, task_id: int) -> None: """Set specific task to be stopped. @@ -37,7 +37,7 @@ def _set_task_stopped(self, task_id: int) -> None: task_id : int The task id to be stopped. """ - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member + _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # type: ignore # pylint: disable=no-member def _is_task_running(self, task_id: int) -> bool: """Check whether the task is running. @@ -52,7 +52,7 @@ def _is_task_running(self, task_id: int) -> bool: bool Whether the task is running. """ - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member + return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member def _join_running_task(self, task_id: int) -> None: """Wait until the task is finished. @@ -62,7 +62,7 @@ def _join_running_task(self, task_id: int) -> None: task_id : int The task id to be joined. """ - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member + _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member def _next_task_id(self) -> int: """Fetch the next task id. @@ -72,7 +72,7 @@ def _next_task_id(self) -> int: int The next task id. """ - return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member + return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member @register_object("meta_schedule.PyTaskScheduler") @@ -98,7 +98,7 @@ def f_next_task_id() -> int: return self._next_task_id() self.__init_handle_by_constructor__( - _ffi_api.TaskSchedulerPyTaskScheduler, # pylint: disable=no-member + _ffi_api.TaskSchedulerPyTaskScheduler, # type: ignore # pylint: disable=no-member f_tune, f_set_task_stopped, f_is_task_running, @@ -110,13 +110,13 @@ def tune(self) -> None: raise NotImplementedError() def _set_task_stopped(self, task_id: int) -> None: - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member + _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # type: ignore # pylint: disable=no-member def _is_task_running(self, task_id: int) -> bool: - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member + return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member def _join_running_task(self, task_id: int) -> None: - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member + _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member def _next_task_id(self) -> int: - return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member + return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index bf2ef17fb308..c79137d55dda 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -20,7 +20,7 @@ import shutil from typing import Any, Callable, List, Optional, Union -import psutil +import psutil # type: ignore import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index ecc8ba5d17b0..aba4663d5931 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -15,6 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +set -e +set -u set -o pipefail echo "Checking MyPy Type defs in the TensorIR schedule package." From 08018eac9e2ce235a230b17dbdb1c1c8f2422798 Mon Sep 17 00:00:00 2001 From: Yuanjing Shi Date: Thu, 14 Oct 2021 16:49:15 -0700 Subject: [PATCH 36/84] [TIR] Add support for 0-dim buffer (#9224) --- src/te/operation/create_primfunc.cc | 6 -- src/tir/ir/script/script_complete.cc | 2 +- .../unittest/test_tvmscript_complete.py | 29 +++++++++ tests/python/unittest/test_tvmscript_ops.py | 59 +++++++++++++++++++ 4 files changed, 89 insertions(+), 7 deletions(-) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index a47556bac101..aa164b03a2a7 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -102,12 +102,6 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: f_push_block_vars(compute_op->axis); f_push_block_vars(compute_op->reduce_axis); - // If we have a rank 0 tensor then we manifest it as a rank 1 buffer with a single element. - if (compute_op->axis.size() == 0) { - iter_vars.push_back(IterVar(Range::FromMinExtent(0, 1), Var(), IterVarType::kDataPar)); - bindings.push_back(Var()); - } - // Step 2. Declare buffer and update op2buffers Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index f265a8ae2b1b..3c4604d18a0d 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -122,7 +122,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { // generate surrounding loops automatically Stmt res = script_completer(func->body); // generate root block automatically - if (script_completer.contains_block && !contain_root) { + if ((script_completer.contains_block || root_allocates.size()) && !contain_root) { res = Block({}, {}, {}, "root", res, NullOpt, root_allocates); res = BlockRealize({}, Bool(true), Downcast(res)); } diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 7c521db21bb8..4704b27fa5fa 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -272,6 +272,34 @@ def test_complete_match_buffer(): tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) +@T.prim_func +def alloc_buffer_func(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [2, 2], dtype="float32") + B = T.match_buffer(b, [2, 2], dtype="float32") + C = T.alloc_buffer([2, 2], dtype="float32") + A[(0, 0)] = T.float32(2) + C[(0, 0)] = A[(0, 0)] + B[(0, 0)] + B[(0, 0)] = C[(0, 0)] + + +@T.prim_func +def expect_alloc_buffer_func(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) + B = T.match_buffer(b, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) + with T.block([], "root"): + T.reads([]) + T.writes([]) + C = T.alloc_buffer([2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) + A[(0, 0)] = T.float32(2) + C[(0, 0)] = A[(0, 0)] + B[(0, 0)] + B[(0, 0)] = C[(0, 0)] + + +def test_complete_alloc_buffer(): + rt_func = tvm.script.from_source(alloc_buffer_func.script(show_meta=True)) + tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func) + + if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() @@ -279,3 +307,4 @@ def test_complete_match_buffer(): test_complete_part_region() test_complete_buffer_indices() test_complete_match_buffer() + test_complete_alloc_buffer() diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py index c55fd7b69282..0aa043e09022 100644 --- a/tests/python/unittest/test_tvmscript_ops.py +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -101,5 +101,64 @@ def test_get_valid_counts_script_func(): _check_get_valid_counts_with_numpy(f, (1, 2500, 6), 0.0, 0, 1) +@T.prim_func +def alloc_zero_dim_buffer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [], dtype="float32") + B = T.match_buffer(b, [], dtype="float32") + # body + # tir.with block("root") + C = T.alloc_buffer([], dtype="float32") + A[()] = T.float32(2) + C[()] = A[()] + B[()] + B[()] = C[()] + + +@T.prim_func +def alloc_zero_dim_buffer_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.match_buffer(b, (), "float32") + with T.block([], "root"): + T.reads([]) + T.writes([]) + C = T.alloc_buffer((), "float32") + A[()] = T.float32(2) + C[()] = A[()] + B[()] + B[()] = C[()] + + +def _check_alloc_zero_dim_buffer(f): + dtype = "float32" + ctx = tvm.cpu() + + np_data = np.zeros(shape=()).astype(dtype) + np_out = np.zeros(shape=()).astype(dtype) + tvm_data = tvm.nd.array(np_data, ctx) + tvm_out = tvm.nd.array(np_out, ctx) + + # np func exection + np_inter = np.array(1) + np_data[()] = 2.0 + np_inter[()] = np_data[()] + np_out[()] + np_out[()] = np_inter[()] + + # tvm func execution + f(tvm_data, tvm_out) + tvm.testing.assert_allclose(tvm_out.numpy(), np_out, rtol=1e-5) + + +def test_alloc_zero_dim_buffer_round_trip(): + func = alloc_zero_dim_buffer + func_with_block = alloc_zero_dim_buffer_block + rt_func = tvm.script.from_source(func.script(show_meta=True)) + rt_func_with_block = tvm.script.from_source(func_with_block.script(show_meta=True)) + rt_mod = tvm.build(rt_func, "llvm") + rt_mod_with_block = tvm.build(rt_func_with_block, "llvm") + tvm.ir.assert_structural_equal(func, func_with_block) + tvm.ir.assert_structural_equal(rt_func, rt_func_with_block) + _check_alloc_zero_dim_buffer(rt_mod) + _check_alloc_zero_dim_buffer(rt_mod_with_block) + + if __name__ == "__main__": test_get_valid_counts_script_func() + test_alloc_zero_dim_buffer_round_trip() From f4db899ab52ab81575762872b7560b5fdf909428 Mon Sep 17 00:00:00 2001 From: Philipp van Kempen Date: Fri, 15 Oct 2021 01:55:25 +0200 Subject: [PATCH 37/84] [TFLite] Add option to overwrite OperatorConverter class in relay.frontend.from_tflite (#9256) * [TFLite] Relay Frontend: Add option to overwrite OperatorConverter class This allows to overwrite the mapping from TFLite Operators to TVM Relay Operators from external python scripts. This has the following advantages: - Adding support for unsupported builtin or even custom operators by adding a hand-written convert function - Enables overwriting of existing convert functions for supported operators by alternative implementations (useful for currently unsupported edge cases) Example Usage: ``` class CustomOperatorConverter(relay.frontend.tflite.OperatorConverter): def __init__(self, model, subgraph, exp_tab): super(CustomOperatorConverter, self).__init__(model, subgraph, exp_tab) convert_map_overwrite = {"SUB": self.convert_sub_custom} self.convert_map.update(convert_map_overwrite) def convert_sub_custom(self, op): ... ... relay_mod = relay.frontend.from_tflite( tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=CustomOperatorConverter ) ``` [TFLite] Make sure that even DETECTION_POSTPROCESS op can be overwritten This is desirable, because the current implementation of this CUSTOM op is incompatible with MicroTVM targets * Tests: added test case for overwriting op_converter in TFLite relay frontend Kept the test as simple as possible by only comparing 2 different implementations of a SUB TFLite operator: 1. Original: c = a - b 2. Dummy: c = a + (-b) Comparison with TFLite reference output is not necessary because tis is already covered by other test cases. Instead comparisons of the two TVM models are used. --- python/tvm/relay/frontend/tflite.py | 9 ++- tests/python/frontend/tflite/test_forward.py | 72 +++++++++++++++++++- 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a66fc4736a98..3688ff5ff4e5 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -66,6 +66,7 @@ def __init__(self, model, subgraph, exp_tab): self.activation_fn_type = build_str_map(ActivationFunctionType()) self.builtin_options = build_str_map(BuiltinOptions()) self.prefetched_nodes = {} + self.allow_custom_ops = False # Add more operators self.convert_map = { @@ -287,6 +288,10 @@ def get_op_code_str(self, op): if op_code_id == BuiltinOperator.CUSTOM: # Custom operator custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode() + + if self.allow_custom_ops: + return "CUSTOM" + if custom_op_code_str == b"TFLite_Detection_PostProcess": return "DETECTION_POSTPROCESS" @@ -3695,7 +3700,7 @@ def _input_type(model): return shape_dict, dtype_dict -def from_tflite(model, shape_dict=None, dtype_dict=None): +def from_tflite(model, shape_dict=None, dtype_dict=None, op_converter=OperatorConverter): """Convert from tflite model into compatible relay Function. Parameters @@ -3755,7 +3760,7 @@ def from_tflite(model, shape_dict=None, dtype_dict=None): exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) # op code in model - op_converter = OperatorConverter(model, subgraph, exp_tab) + op_converter = op_converter(model, subgraph, exp_tab) op_converter.check_unsupported_ops() op_converter.convert_op_to_relay() diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 4a6f88417b9c..754976ca8c13 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -161,6 +161,7 @@ def run_tvm_graph( target="llvm", out_names=None, mode="graph_executor", + op_converter=relay.frontend.tflite.OperatorConverter, ): """Generic function to compile on relay and execute on tvm""" # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 @@ -185,7 +186,7 @@ def run_tvm_graph( dtype_dict[e] = input_data[i].dtype.name mod, params = relay.frontend.from_tflite( - tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=op_converter ) if mode in ["debug", "vm"]: @@ -3996,6 +3997,72 @@ def test_detection_postprocess(): ) +####################################################################### +# Custom Converter +# ---------------- + + +def test_custom_op_converter(): + """Test case for user-defined operator converter in TFLite frontend""" + + class DummyOperatorConverter(relay.frontend.tflite.OperatorConverter): + """Operator Converter for converting TFLite ops to relay ops""" + + def __init__(self, model, subgraph, exp_tab): + super(DummyOperatorConverter, self).__init__(model, subgraph, exp_tab) + self.allow_custom_ops = True + + convert_map_overwrite = {"SUB": self.convert_sub_dummy} + + self.convert_map.update(convert_map_overwrite) + + def convert_sub_dummy(self, op): + """Convert TFLite SUB""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + lhs_tensor = input_tensors[0] + rhs_tensor = input_tensors[1] + + lhs_expr = self.get_expr(lhs_tensor.tensor_idx) + rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + + temp_expr = relay.op.negative(rhs_expr) + out = relay.op.add(lhs_expr, temp_expr) + + return out + + with tf.Graph().as_default(): + # Generate TFLite model for single addition + data = [ + np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), + np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)), + ] + in_data = [ + array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in_0"), + array_ops.placeholder(shape=data[1].shape, dtype="float32", name="in_1"), + ] + out = math_ops.subtract(in_data[0], in_data[1]) + in_name = [x[1] for x in zip(in_data, ("in_0:0", "in_1:0"))] + input_tensors = [x for x in in_data] + output_tensors = [out] + in_node = [0] * len(in_name) + for i in range(len(in_name)): + in_node[i] = in_name[i].split(":")[0] if ":" in in_name[i] else in_name[i] + + with tf.Session() as sess: + converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors) + tflite_model_buf = converter.convert() + in_data = [x[1] for x in zip(in_data, data)] + tvm_output_orig = run_tvm_graph(tflite_model_buf, in_data, in_node) + tvm_output_dummy = run_tvm_graph( + tflite_model_buf, in_data, in_node, op_converter=DummyOperatorConverter + ) + tvm.testing.assert_allclose( + np.squeeze(tvm_output_orig[0]), np.squeeze(tvm_output_dummy[0]), rtol=1e-5, atol=1e-5 + ) + + ####################################################################### # Mobilenet # --------- @@ -4621,6 +4688,9 @@ def test_prevent_tensorflow_dynamic_range(): # Detection_PostProcess test_detection_postprocess() + # Overwrite Converter + test_custom_op_converter() + # End to End test_forward_mobilenet_v1() test_forward_mobilenet_v2() From acff61c8554ff6e52fa8340f559f65786b5f832e Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 15 Oct 2021 11:20:10 +0800 Subject: [PATCH 38/84] [Frontend][PaddlePaddle] Remove unused parameters and fix doc string (#9283) * add part of operators * remove part of operators * add lookup * add test * Update paddlepaddle.py * modify error message for SAME padding * Remove some function and old version operator * Remove some function and old version operator * Remove some function and old version operator * Remove some function and old version operator * add dot test * modify doc * remove unreviewed code * Update paddlepaddle.py * Update test_forward.py * Update paddlepaddle.py * Update paddlepaddle.py * Update test_forward.py * Update test_forward.py * add more cases for tests * add more cases for tests * remove annotation * reduce test case sizes * Remove unused parameters and fix doc string for paddle frontend * remove blank line * fix code error * modify test_forward.py --- python/tvm/relay/frontend/paddlepaddle.py | 33 +++++++++++++++++-- .../frontend/paddlepaddle/test_forward.py | 3 +- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 378002a74416..c32449546f77 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -23,6 +23,7 @@ import tvm from tvm.ir import IRModule +from ... import nd as _nd from .. import analysis from .. import ty as _ty from .. import expr as _expr @@ -954,10 +955,12 @@ def extract_parameters(self, program, scope=None): if not var.persistable: continue if isinstance(scope, dict): - self.params[name] = scope[name] + self.params[name] = _nd.array(scope[name]) else: - self.params[name] = np.array(scope.var(name).get_tensor()) - self.nodes[name] = _expr.const(self.params[name]) + self.params[name] = _nd.array(np.array(scope.var(name).get_tensor())) + shape = self.params[name].shape + dtype = self.params[name].dtype + self.nodes[name] = new_var(name, shape=shape, dtype=dtype) def check_input_shape(self, op, block): """Check the shape information of model's inputs, fixed shape is recommended.""" @@ -1048,6 +1051,12 @@ def from_translated_layer(self, layer, shape_dict): free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) + # remove unused parameters + final_params = dict() + for var in free_vars: + if var.name_hint in self.params: + final_params[var.name_hint] = self.params[var.name_hint] + self.params = final_params return mod, self.params @@ -1056,6 +1065,24 @@ def from_paddle(program_or_layer, shape_dict=None, scope=None): PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, and PaddlePaddle scope stores all the weights of PaddlePaddle model. + + Parameters + ---------- + program_or_layer : object of `paddle.static.Program` or `paddle.jit.TranslatedLayer` + Loaded model by `paddle.static.load_inference_model` or `paddle.jit.load` + + shape_dict : dict of str to tuple/list, optional + The input shape of model + + scope : object of `paddle.static.Scope`, optional + The scope that saves all the weights of model, use `paddle.static.global_scope` by default + + Returns + ------- + mod : tvm.IRModule + The relay module for compilation + + params : dict of str to tvm.nd.NDArray """ import paddle diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 1d64f947e68a..e3d1fc9daf2b 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -80,9 +80,8 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5): baseline_outputs = (baseline_outputs.numpy(),) mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) - parms_num = min(len(input_names), len(mod["main"].params)) compiled_names = [] - for arg in mod["main"].params[:parms_num]: + for arg in mod["main"].params: assert arg.name_hint in input_names or arg.name_hint in params if arg.name_hint in input_names: compiled_names.append(arg.name_hint) From c00ce37b71f30a77bffaeef47bc12239bc2b2f3c Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 14 Oct 2021 20:20:29 -0700 Subject: [PATCH 39/84] [Profiler] Do not aggregate frames with different devices (#9290) --- src/runtime/profiling.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index bd59be87f7d9..d8cea11dc078 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -311,6 +311,9 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { if (frame.find("Argument Shapes") != frame.end()) { name += Downcast(frame["Argument Shapes"]); } + if (frame.find("Device") != frame.end()) { + name += Downcast(frame["Device"]); + } if (aggregates.find(name) == aggregates.end()) { aggregates[name] = {i}; From afcf80c581bcd556dcb7a2814bea2c2eecc35f6e Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 15 Oct 2021 22:43:27 -0500 Subject: [PATCH 40/84] [Hexagon] Fix addressing TVMValue array (#9302) --- src/target/llvm/codegen_hexagon.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index bd22532d998c..5414366c1cd7 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -644,7 +644,7 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V } else { ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(t_void_p_, buf, index); + buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); } } From 39d32c35dde7ea499ebbf6d6475cacb569c7a114 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 15 Oct 2021 20:46:17 -0700 Subject: [PATCH 41/84] [Profiler] Sort columns in table and csv output (#9300) This should fix some problems around flakey profiler tests based on order of columns. --- src/runtime/profiling.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index d8cea11dc078..8b37bfcd539d 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -181,7 +181,7 @@ String ShapeString(const std::vector& shapes) { String ReportNode::AsCSV() const { // get unique headers - std::unordered_set unique_headers; + std::set unique_headers; for (auto row : calls) { for (auto p : row) { @@ -407,7 +407,7 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { } // Table formatting - std::unordered_set unique_headers; + std::set unique_headers; for (auto row : aggregated_calls) { for (auto p : row) { @@ -415,10 +415,11 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { } } - std::vector headers = {"Name", "Duration (us)", - "Percent"}; // always include these headers + // always include these headers in this order + std::vector headers = {"Name", "Duration (us)", "Percent", + "Device", "Count", "Argument Shapes"}; for (auto header : unique_headers) { - if (header != "Name" && header != "Duration (us)" && header != "Percent") { + if (std::find(headers.begin(), headers.end(), header) == headers.end()) { headers.push_back(header); } } From c279b94a659837ca90b83ee1f5b8d500a4f6d5ac Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Fri, 15 Oct 2021 23:01:01 -0500 Subject: [PATCH 42/84] [TE] Light refactoring of TE -> TIR paths. (#9263) * [TE] Light refactoring of TE -> TIR paths. - Added ScheduleToPrimFunc, extracting out common behavior in ScheduleToModule and auto_scheduler's feature extraction. - Added `tvm.driver.build_module.schedule_to_module`, to avoid needing to 4-line boilerplate needed to do so. Also makes deviations from the usual path (e.g. `debug_keep_trivial_loop`) much more explicit. * Removed schedule_to_primfunc, replaced usage with schedule_to_module. * Returned C++ function ScheduleToPrimfunc to be inside ScheduleToModule. --- python/tvm/autotvm/feature.py | 10 ++-- python/tvm/driver/build_module.py | 5 ++ .../backend/contrib/ethosu/tir/compiler.py | 15 ++---- src/auto_scheduler/feature.cc | 38 ++++----------- src/driver/driver_api.cc | 19 ++++---- tests/python/integration/test_reduce.py | 8 ++-- tests/python/unittest/test_te_schedule_ops.py | 26 ++++------ .../test_tir_transform_inject_copy_intrin.py | 18 ++----- .../test_tir_transform_make_packed_api.py | 9 ++-- ...merge_dynamic_shared_memory_allocations.py | 13 ++--- .../test_tir_transform_narrow_datatype.py | 9 ++-- .../test_tir_transform_storage_flatten.py | 13 ++--- .../test_tir_transform_storage_rewrite.py | 47 ++++--------------- 13 files changed, 75 insertions(+), 155 deletions(-) diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index 8d2591dce50b..f73c65fbd1d8 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -31,7 +31,6 @@ import tvm._ffi from tvm.target import Target -from tvm.te import schedule from tvm.driver import build_module @@ -39,13 +38,12 @@ def ana_lower(sch, args, binds=None, simple_mode=True): """Do lower while keeping all axes in IR i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads """ - binds, _ = build_module.get_binds(args, compact=False, binds=binds) sch = sch.normalize() # Phase 0 - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds, True) - func = schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func._move()) + context = tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop": True}) + with context: + mod = build_module.schedule_to_module(sch, args, binds=binds) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) assert simple_mode diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 429b3e1727cc..29fff775150f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -67,6 +67,11 @@ def schedule_to_module( binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, ) -> IRModule: """According to the given schedule, form a function. + + This is a low-level function intended for testing purposes, and + does not apply any optimization passes. In general, `tvm.lower` + and `tvm.build` should be used instead. + Parameters ---------- sch : tvm.te.schedule.Schedule diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index 3283e0515c72..c792ade06643 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -19,7 +19,7 @@ import tvm from tvm import relay from tvm.relay.expr_functor import ExprMutator -from tvm.driver.build_module import get_binds +from tvm.driver.build_module import schedule_to_module from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants from .scheduler import schedule @@ -64,22 +64,17 @@ def lower_ethosu(sch, args, const_dict, name="main"): "no_unroll_loop_with_extent_one": True, }, "tir.UnrollLoop": {"auto_max_depth": -1}, + "tir.noalias": True, + "tir.debug_keep_trivial_loop": True, } # Merge two configs curr_cfg = {**curr_cfg, **tir_compiler_cfg} sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True) - compact = tvm.te.schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, None) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) with tvm.transform.PassContext(config=curr_cfg): + mod = schedule_to_module(sch, args, name) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.UnrollLoop()(mod) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index be78bc4aa9f9..aaf7d48b10c5 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -44,13 +45,6 @@ #include "search_policy/utils.h" #include "utils.h" -namespace tvm { -// import the function from driver_api.cc -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list); -} // namespace tvm - namespace tvm { namespace auto_scheduler { @@ -1268,35 +1262,25 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i Array tensors; std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); + + // When inlining, replace const matrices with const values. + // Produces wrong IR, but good enough for feature extraction, and + // can improve the speed of feature extraction/search. Must be + // called before ScheduleToModule to have an effect. sch = sch.normalize_for_feature_extraction(); - auto bounds = te::InferBound(sch); try { - auto stmt = te::ScheduleOps(sch, bounds, false); - Map out_binds; - Array out_arg_list; - bool compact = te::VerifyCompactBuffer(stmt); const std::string& name = "main"; - GlobalVar global_var(name); - - // Copied from driver_api.cc::lower auto pass_ctx = tvm::transform::PassContext::Current(); - GetBinds(tensors, compact, std::unordered_map(), &out_binds, - &out_arg_list); - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + auto mod = ScheduleToModule(sch, Array{tensors.begin(), tensors.end()}, name, + std::unordered_map()); + bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - auto mod = IRModule(Map({{global_var, f}})); - if (IsGPUTask(task)) { auto pass_list = Array(); // Phase 0 @@ -1323,9 +1307,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i const auto& optimize = tir::transform::Sequential(Array{tir::transform::Simplify()}); mod = optimize(std::move(mod)); - const auto& it = mod->functions.find(global_var); - ICHECK(it != mod->functions.end()); - const auto& prim_func = (*it).second.as(); + PrimFunc prim_func = Downcast(mod->Lookup(name)); GetPerStoreFeature(prim_func->body, task->hardware_params->cache_line_bytes, max_n_bufs, feature); } catch (Error& e) { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e659421c23c4..2d57d6e30b45 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); using runtime::PackedFunc; using runtime::TVMArgs; @@ -287,24 +288,24 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { return mod; } +// Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { - // Convert te schedule to IRModule - Array out_arg_list; - transform::PassContext pass_ctx = transform::PassContext::Current(); - sch = sch.normalize(); + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool debug_keep_trivial_loop = + pass_ctx->GetConfig("tir.debug_keep_trivial_loop", Bool(false)).value(); + // Before TIR transformation. - Map bounds = te::InferBound(sch); - tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false); + tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop); bool compact = te::VerifyCompactBuffer(stmt); Map out_binds; + Array out_arg_list; GetBinds(args, compact, binds, &out_binds, &out_arg_list); - // Build the function - // At this point binds is only te::Tensors + // Build the function, converting from te::Tensor to tir::Buffer tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); @@ -325,7 +326,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") const Map& binds) { std::unordered_map c_binds; // Check to make sure binds is not null before doing the conversion; - if (binds.get() != nullptr) { + if (binds.defined()) { for (auto kv : binds) { c_binds.insert({kv.first, kv.second}); } diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index ca097734a9eb..a40164ded941 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np import tvm from tvm import te, topi -import numpy as np +from tvm.driver.build_module import schedule_to_module import tvm.testing import tvm.topi.testing @@ -532,10 +533,7 @@ def test_reduce_storage_reuse(): target = tvm.target.Target("cuda") def run_passes(sch, args): - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) return tvm.transform.Sequential( [ diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index bc4bc4f56e19..ca3ab3aade98 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np + import tvm from tvm import te -import numpy as np +from tvm.driver.build_module import schedule_to_module def test_schedule0(): @@ -26,11 +28,8 @@ def test_schedule0(): A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1") s = te.create_schedule(A1.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A1]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule1(): @@ -42,12 +41,9 @@ def test_schedule1(): s = te.create_schedule(A1.op) xo, xi = s[A1].split(A1.op.axis[0], 8) s[A1].pragma(xo, "auto_unroll_max_step", 10) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A1], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + mod = schedule_to_module(s, [A, A1]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule2(): @@ -60,11 +56,9 @@ def test_schedule2(): s = te.create_schedule(A2.op) xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - assert isinstance(func, tvm.tir.PrimFunc) + + mod = schedule_to_module(s, [A, A2]) + assert isinstance(mod["main"], tvm.tir.PrimFunc) def test_schedule_scan(): diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 86bf87d5fa85..aa0448c3c682 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm import te +from tvm.driver.build_module import schedule_to_module def test_copy2d(): @@ -53,11 +54,7 @@ def test_copy_pad(): ) s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): @@ -77,11 +74,7 @@ def test_single_point_test(): B = te.compute((1,), lambda i: A[i], name="B") s = te.create_schedule(B.op) s[B].pragma(B.op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): @@ -105,11 +98,8 @@ def test_copy_pad_split(): xo, xi = s[B].split(B.op.axis[0], factor=4) s[Apad].compute_at(s[B], xo) s[Apad].pragma(s[Apad].op.axis[0], "memcpy") - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) mod = tvm.tir.transform.Simplify()(mod._move()) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 15f994069abd..1ab6bdaad90a 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy + import tvm from tvm import te -import numpy +from tvm.driver.build_module import schedule_to_module def test_makeapi(): @@ -27,10 +29,7 @@ def test_makeapi(): C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") s = te.create_schedule(C.op) - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [n, A, B, C]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Apply( lambda f: f.with_attr( diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 9c511f1de6b9..cc78b84f9b4e 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -14,20 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te import numpy as np + +import tvm import tvm.testing +from tvm import te +from tvm.driver.build_module import schedule_to_module from tvm.topi.math import cast def run_passes(sch, args): - bounds = tvm.te.schedule.InferBound(sch) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) return tvm.transform.Sequential( [ tvm.tir.transform.StorageFlatten(64), diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index cb8968cfc880..b5620d748d8a 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te -from tvm import relay +from tvm import te, relay +from tvm.driver.build_module import schedule_to_module from tvm.tir import const @@ -39,11 +39,8 @@ def lower_sch(sch, args, target_bits): else: raise ValueError("args must be Tensor, Buffer or Var") sch = sch.normalize() - bounds = te.schedule.InferBound(sch) - stmt = te.schedule.ScheduleOps(sch, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(sch, args) mod = tvm.tir.transform.StorageFlatten(64)(mod) return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 37223493a8b5..a51e926155d3 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.driver.build_module import schedule_to_module from tvm.script import tir as T from tvm.relay import GlobalVar @@ -30,14 +31,10 @@ def test_flatten2(): s = te.create_schedule(A2.op) xo, xi = s[A2].split(A2.op.axis[0], 8) s[A1].compute_at(s[A2], xo) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A") A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2") - func = tvm.te.schedule.SchedulePostProcToPrimFunc([Ab, A2b], stmt, {A: Ab, A2: A2b}) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [Ab, A2b], binds={A: Ab, A2: A2b}) mod = tvm.tir.transform.StorageFlatten(64)(mod) @@ -70,12 +67,8 @@ def test_flatten_storage_align(): s = te.create_schedule(A2.op) s[A1].storage_align(A1.op.axis[0], 2, 1) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, A2]) mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 9e738b136b17..5a91788283d6 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.driver.build_module import schedule_to_module def test_storage_share(): @@ -28,12 +29,7 @@ def test_storage_share(): B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t) s = te.create_schedule(B.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -169,12 +165,7 @@ def test_inplace_rule(): AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA") B = te.compute((m,), lambda i: AA[i] + 1, name="B") s = te.create_schedule(B.op) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -206,11 +197,8 @@ def test_storage_combine(): s = te.create_schedule(B.op) for S in stages[:-1]: s[S].set_scope("global:tag") - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + + mod = schedule_to_module(s, [A, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -238,10 +226,7 @@ def test_storage_combine_with_vectorization(): BB = s.cache_read(B, "global:tag", readers=[C]) CC = s.cache_write(C, "global:tag") s[CC].vectorize(s[CC].op.axis[0]) - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B, C]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.VectorizeLoop()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) @@ -285,11 +270,7 @@ def test_storage_share_gpu(): s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx) s[A[2 * t + 1]].set_scope("shared") - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A[0], A[-1]]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) @@ -418,12 +399,7 @@ def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): A0L = s.cache_read(A0, scope_tb, [A2]) A1L = s.cache_read(A1, scope_tb, [A2]) A2L = s.cache_read(A2, scope_tb, [B]) - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [A, B, C, D]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -511,12 +487,7 @@ def test_inplace_rule3(): s[B10].compute_inline() s = s.normalize() - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - func = tvm.te.schedule.SchedulePostProcToPrimFunc([B0, B1, B2, B3, B4, B5, B], stmt, None) - mod = tvm.IRModule.from_expr(func) + mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B]) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) From 236e4c7f3ef2c7c9970b371afb986f20cc38a845 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Sun, 17 Oct 2021 00:13:45 +0300 Subject: [PATCH 43/84] [iOS] Fix build issues on the latest XCode and iOS (#9298) 1. Specify target for compiled dylib 2. Specify Metal Shader Language version when we compile metal library. --- python/tvm/contrib/xcode.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/xcode.py b/python/tvm/contrib/xcode.py index c44a2fe4a136..6d5e10f611db 100644 --- a/python/tvm/contrib/xcode.py +++ b/python/tvm/contrib/xcode.py @@ -45,7 +45,23 @@ def xcrun(cmd): return out.strip() -def create_dylib(output, objects, arch, sdk="macosx"): +def __get_min_os_version(sdk): + if sdk in ("macosx", "iphonesimulator"): + return None + if sdk == "iphoneos": + return "13.0" + raise RuntimeError("Unsupported sdk: %s" % sdk) + + +def __get_min_os_version_cmd(sdk, min_os_version): + if min_os_version is None: + min_os_version = __get_min_os_version(sdk) + if min_os_version is not None: + return "-mios-version-min=" + min_os_version + return "" + + +def create_dylib(output, objects, arch, sdk="macosx", min_os_version=None): """Create dynamic library. Parameters @@ -71,6 +87,7 @@ def create_dylib(output, objects, arch, sdk="macosx"): cmd += ["-dynamiclib"] cmd += ["-arch", arch] cmd += ["-isysroot", sdk_path] + cmd += [__get_min_os_version_cmd(sdk, min_os_version)] cmd += ["-o", output] if isinstance(objects, str): cmd += [objects] @@ -90,7 +107,7 @@ def create_dylib(output, objects, arch, sdk="macosx"): create_dylib.output_format = "dylib" -def compile_metal(code, path_target=None, sdk="macosx"): +def compile_metal(code, path_target=None, sdk="macosx", min_os_version=None): """Compile metal with CLI tool from env. Parameters @@ -123,7 +140,14 @@ def compile_metal(code, path_target=None, sdk="macosx"): # # xcrun -sdk macosx metal -c MyLibrary.metal -o MyLibrary.air # xcrun -sdk macosx metallib MyLibrary.air -o MyLibrary.metallib - cmd1 = ["xcrun", "-sdk", sdk, "metal", "-O3"] + min_target = __get_min_os_version_cmd(sdk, min_os_version) + if sdk == "macosx": + language_version = "-std=macos-metal2.3" + elif sdk in ("iphoneos", "iphonesimulator"): + language_version = "-std=ios-metal2.3" + else: + raise RuntimeError("Unsupported sdk: %s" % sdk) + cmd1 = ["xcrun", "-sdk", sdk, "metal", language_version, min_target, "-O3"] cmd1 += ["-c", temp_code, "-o", temp_ir] cmd2 = ["xcrun", "-sdk", sdk, "metallib"] cmd2 += [temp_ir, "-o", file_target] From f5eb4c2d08edbb09d1971ba85462c1df90acb07c Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Sun, 17 Oct 2021 00:15:30 +0300 Subject: [PATCH 44/84] Rename build helper (#9297) --- python/tvm/driver/build_module.py | 2 +- src/driver/driver_api.cc | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 29fff775150f..5ec44c6d6ed1 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -260,7 +260,7 @@ def build( target_input_mod, target_host ) - rt_mod_host = _driver_ffi.finalize_module(target_input_mod, target_host) + rt_mod_host = _driver_ffi.preprocess_module(target_input_mod, target_host) target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 2d57d6e30b45..24cae798988e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -426,7 +426,8 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& return {host_mod, device_mod}; } -runtime::Module FinalizeModule(const Map& inputs_arg, const Target& host_target) { +runtime::Module PreProcessModuleForBuild(const Map& inputs_arg, + const Target& host_target) { std::vector device_modules; Map inputs = inputs_arg; Target target_host = host_target; @@ -480,9 +481,9 @@ runtime::Module FinalizeModule(const Map& inputs_arg, const Ta return complete_mod; } -TVM_REGISTER_GLOBAL("driver.finalize_module") +TVM_REGISTER_GLOBAL("driver.preprocess_module") .set_body_typed([](const Map& inputs_arg, Target host_target) { - return FinalizeModule(inputs_arg, host_target); + return PreProcessModuleForBuild(inputs_arg, host_target); }); runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { From 2b06ab312fb27c7eb567dfd128a7ee9470e7809e Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Sun, 17 Oct 2021 17:19:05 +0800 Subject: [PATCH 45/84] [TVMC] Support dot inside of TVMC input shape name arguments (#9294) * [TVMC] Support dot inside of TVMC input shape name arguments * dot -> dots --- python/tvm/driver/tvmc/common.py | 5 +++-- tests/python/driver/tvmc/test_shape_parser.py | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index f4bc3ec027d7..1ee24cf69d44 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -418,7 +418,7 @@ def parse_shape_string(inputs_string): ---------- inputs_string: str A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that - indicates the desired shape for specific model inputs. Colons and forward slashes + indicates the desired shape for specific model inputs. Colons, forward slashes and dots within input_names are supported. Spaces are supported inside of dimension arrays. Returns @@ -432,7 +432,8 @@ def parse_shape_string(inputs_string): # * Spaces inside arrays # * forward slashes inside names (but not at the beginning or end) # * colons inside names (but not at the beginning or end) - pattern = r"(?:\w+\/)?[:\w]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + # * dots inside names + pattern = r"(?:\w+\/)?[:\w.]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" input_mappings = re.findall(pattern, inputs_string) if not input_mappings: raise argparse.ArgumentTypeError( diff --git a/tests/python/driver/tvmc/test_shape_parser.py b/tests/python/driver/tvmc/test_shape_parser.py index c021078630ed..f49d89ac7c0f 100644 --- a/tests/python/driver/tvmc/test_shape_parser.py +++ b/tests/python/driver/tvmc/test_shape_parser.py @@ -94,3 +94,10 @@ def test_invalid_colon(): def test_invalid_slashes(shape_string): with pytest.raises(argparse.ArgumentTypeError): tvmc.common.parse_shape_string(shape_string) + + +def test_dot(): + # Check dot in input name + shape_string = "input.1:[10,10,10]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input.1": [10, 10, 10]} From 97f996cafcb1e4775cd3b77baa103f15906c0401 Mon Sep 17 00:00:00 2001 From: qingchao Date: Sun, 17 Oct 2021 18:47:23 +0800 Subject: [PATCH 46/84] fix typo (#9304) --- python/tvm/relay/op/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index e47928919ce1..e615bbf21b86 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -1070,7 +1070,7 @@ def fixed_point_multiply(data, multiplier, shift): The input tensor. multiplier : int The integer multiplier of the fixed point constant. - a_max : float + shift : int The integer shift of the fixed point constant. Returns From d095a968dfd6f2f65b506b45da6f091228a8b19a Mon Sep 17 00:00:00 2001 From: powderluv Date: Sun, 17 Oct 2021 03:47:58 -0700 Subject: [PATCH 47/84] llvm 14 and above move TargetRegistry into MC (#9305) TEST=build with latest llvm --- src/target/llvm/llvm_common.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index b967c7ad44e0..fcc44fb8f95c 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -72,7 +72,11 @@ #include #include #include +#if TVM_LLVM_VERSION >= 140 +#include +#else #include +#endif #include #include #include From 151696fbc8808f128c4bf163fcfff5e485c095c5 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sun, 17 Oct 2021 22:24:06 -0500 Subject: [PATCH 48/84] [unittests] Skip import of tvm.micro if micro-TVM was not enabled (#9301) --- tests/python/unittest/test_crt.py | 1 + tests/python/unittest/test_micro_project_api.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 62e68ab01ce5..9450a937a155 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -35,6 +35,7 @@ from tvm.topi.utils import get_const_tuple from tvm.topi.testing import conv2d_nchw_python +pytest.importorskip("tvm.micro.testing") from tvm.micro.testing import check_tune_log BUILD = True diff --git a/tests/python/unittest/test_micro_project_api.py b/tests/python/unittest/test_micro_project_api.py index b5e2a57c122c..e319318656ef 100644 --- a/tests/python/unittest/test_micro_project_api.py +++ b/tests/python/unittest/test_micro_project_api.py @@ -25,6 +25,8 @@ import pytest import tvm + +pytest.importorskip("tvm.micro") from tvm.micro import project_api From 5443c7cf6759b939f8425e2b7edb1dc2e21308c6 Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Mon, 18 Oct 2021 14:54:30 -0300 Subject: [PATCH 49/84] [microTVM][RVM] Always destroy the VM if all tests pass (#8739) Currently base-box-tool 'test' command will skip destroying the test VM if a single provider is specified (i.e. --provider virtualbox) even if all tests pass. This is confusing (no warning is displayed to the user) and that will leave host resources (like USB devices necessary to run the test) locked by the VM. So if the user tries to run a program that uses the locked resource (e.g. openocd) cryptic failures might happen. Moreoever, even if all tests pass and more than one provider is specified but the option '--skip-build' is set a VM will be left running without notice. This commit changes that behavior by: 1. Always destroying the VM if the release test pass 2. Always keeping the VM up and running if a test fails 1. guarantees no resource remains locked by the VM without necessity. A new flag '--skip-destroy' is introduced in case the user still wants to keep a VM up and running if the release tests pass. 2. guarantees the VM where the test failed is left running for further inspection of the test that failed. Finally, for both 1. and 2. cases a proper message is displayed to the user to inform if a VM was left running or not and about what actions the user can take next accordingly to the test result in the VM. Signed-off-by: Gustavo Romero --- apps/microtvm/reference-vm/base-box-tool.py | 48 +++++++++++++++------ 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/apps/microtvm/reference-vm/base-box-tool.py b/apps/microtvm/reference-vm/base-box-tool.py index 3a5fd18cede7..42b90c661704 100755 --- a/apps/microtvm/reference-vm/base-box-tool.py +++ b/apps/microtvm/reference-vm/base-box-tool.py @@ -388,13 +388,15 @@ def test_command(args): microtvm_test_config["microtvm_board"] = args.microtvm_board providers = args.provider - provider_passed = {p: False for p in providers} release_test_dir = os.path.join(THIS_DIR, f"release-test-{args.platform}") - if args.skip_build: - assert len(providers) == 1, "--skip-build was given, but >1 provider specified" + if args.skip_build or args.skip_destroy: + assert ( + len(providers) == 1 + ), "--skip-build and/or --skip-destroy was given, but >1 provider specified" + test_failed = False for provider_name in providers: try: if not args.skip_build: @@ -408,18 +410,27 @@ def test_command(args): microtvm_test_config, args.test_device_serial, ) - provider_passed[provider_name] = True + + except subprocess.CalledProcessError: + test_failed = True + sys.exit( + f"\n\nERROR: Provider '{provider_name}' failed the release test. " + "You can re-run it to reproduce the issue without building everything " + "again by passing the --skip-build and specifying only the provider that failed. " + "The VM is still running in case you want to connect it via SSH to " + "investigate further the issue, thus it's necessary to destroy it manually " + "to release the resources back to the host, like a USB device attached to the VM." + ) finally: - if not args.skip_build and len(providers) > 1: + # if we reached out here do_run_release_test() succeeded, hence we can + # destroy the VM and release the resources back to the host if user haven't + # requested to not destroy it. + if not (args.skip_destroy or test_failed): subprocess.check_call(["vagrant", "destroy", "-f"], cwd=release_test_dir) shutil.rmtree(release_test_dir) - if not all(provider_passed[p] for p in provider_passed.keys()): - sys.exit( - "some providers failed release test: " - + ",".join(name for name, passed in provider_passed if not passed) - ) + print(f'\n\nThe release tests passed on all specified providers: {", ".join(providers)}.') def release_command(args): @@ -493,9 +504,20 @@ def parse_args(): "--skip-build", action="store_true", help=( - "If given, assume a box has already been built in " - "the release-test subdirectory. Attach a USB device to this box and execute the " - "release test script--do not delete it." + "If given, assume a box has already been built in the release-test subdirectory, " + "so use that box to execute the release test script. If the tests fail the VM used " + "for testing will be left running for further investigation and will need to be " + "destroyed manually. If all tests pass on all specified providers no VM is left running, " + "unless --skip-destroy is given too." + ), + ) + parser_test.add_argument( + "--skip-destroy", + action="store_true", + help=( + "Skip destroying the test VM even if all tests pass. Can only be used if a single " + "provider is specified. Default is to destroy the VM if all tests pass (and always " + "skip destroying it if a test fails)." ), ) parser_test.add_argument( From d23688c3e5aa51e69d13fa469abeca3c9e054714 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 19 Oct 2021 01:59:17 +0800 Subject: [PATCH 50/84] update block syntax (#9286) --- .../install/ubuntu_install_python_package.sh | 2 +- include/tvm/tir/function.h | 14 +- include/tvm/tir/stmt.h | 6 +- include/tvm/tir/transform.h | 4 +- python/gen_requirements.py | 2 +- python/tvm/script/context_maintainer.py | 29 +- python/tvm/script/parser.py | 41 +- python/tvm/script/tir/scope_handler.py | 215 +++-- python/tvm/script/tir/special_stmt.py | 385 ++++++++- python/tvm/te/operation.py | 10 +- python/tvm/tir/function.py | 12 +- python/tvm/tir/schedule/schedule.py | 215 ++--- python/tvm/tir/transform/transform.py | 4 +- src/printer/tvmscript_printer.cc | 174 +++-- src/tir/ir/script/script_complete.cc | 18 +- tests/python/integration/test_lower.py | 53 +- tests/python/unittest/test_lower_build.py | 8 +- .../unittest/test_meta_schedule_arg_info.py | 10 +- .../unittest/test_meta_schedule_builder.py | 36 +- .../unittest/test_meta_schedule_database.py | 26 +- .../unittest/test_meta_schedule_runner.py | 42 +- .../test_meta_schedule_search_strategy.py | 10 +- .../test_meta_schedule_space_generator.py | 10 +- .../test_meta_schedule_task_scheduler.py | 36 +- .../test_meta_schedule_tune_context.py | 10 +- .../unittest/test_te_create_primfunc.py | 68 +- ...t_tir_analysis_detect_buffer_access_lca.py | 57 +- ...st_tir_analysis_get_block_access_region.py | 58 +- .../unittest/test_tir_lower_match_buffer.py | 44 +- .../unittest/test_tir_schedule_block_scope.py | 24 +- .../test_tir_schedule_cache_read_write.py | 738 ++++++++++-------- .../unittest/test_tir_schedule_compute_at.py | 495 ++++++------ .../test_tir_schedule_compute_inline.py | 186 +++-- .../unittest/test_tir_schedule_error.py | 8 +- .../unittest/test_tir_schedule_for_kind.py | 116 +-- .../unittest/test_tir_schedule_reduction.py | 85 +- .../unittest/test_tir_schedule_reorder.py | 99 ++- .../unittest/test_tir_schedule_rfactor.py | 301 ++++--- .../unittest/test_tir_schedule_sampling.py | 6 +- .../unittest/test_tir_schedule_split_fuse.py | 177 +++-- .../unittest/test_tir_schedule_state.py | 56 +- .../test_tir_schedule_state_cached_flags.py | 204 +++-- .../test_tir_schedule_storage_align.py | 36 +- .../unittest/test_tir_schedule_trace.py | 18 +- .../unittest/test_tir_schedule_utilities.py | 6 +- tests/python/unittest/test_tir_specialize.py | 106 ++- ...est_tir_transform_compact_buffer_region.py | 90 +-- ..._tir_transform_convert_blocks_to_opaque.py | 16 +- .../test_tir_transform_flatten_buffer.py | 18 +- .../test_tir_transform_lower_init_block.py | 55 +- ..._plan_update_buffer_allocation_location.py | 167 ++-- .../unittest/test_tvmscript_complete.py | 100 ++- .../unittest/test_tvmscript_error_report.py | 223 ++++-- tests/python/unittest/test_tvmscript_ops.py | 35 +- .../unittest/test_tvmscript_roundtrip.py | 95 ++- tests/scripts/task_ci_setup.sh | 2 +- 56 files changed, 2967 insertions(+), 2094 deletions(-) diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index d1fa340ac37d..fb0f596d6552 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -36,6 +36,6 @@ pip3 install \ pytest-xdist \ requests \ scipy \ - synr==0.4.1 \ + synr==0.5.0 \ six \ tornado diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 23057f7140e4..e4a3d3d1e21b 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -199,9 +199,10 @@ class LinkedParam : public ObjectRef { * def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: * A = T.match_buffer(a, (m, n), "float32") * B = T.match_buffer(b, (m, n), "float32") - * - * with T.block([m, n], "") as [vi, vj]: - * B[vi, vj] = A[vi, vj] + * for i, j in T.grid(m, n): + * with T.block(): + * vi, vj = T.axis.remap("SS", [i, j]) + * B[vi, vj] = A[vi, vj] * \endcode * * Then we can make it specialized with given shapes or buffers. @@ -218,9 +219,10 @@ class LinkedParam : public ObjectRef { * def mem_copy_16_16(a: T.handle, b: T.handle) -> None: * A = T.match_buffer(a, (16, 16), "float32") * B = T.match_buffer(b, (16, 16), "float32") - * - * with T.block([16, 16], "") as [vi, vj]: - * B[vi, vj] = A[vi, vj] + * for i, j in T.grid(16, 16): + * with T.block(): + * vi, vj = T.axis.remap("SS", [i, j]) + * B[vi, vj] = A[vi, vj] * \endcode */ PrimFunc Specialize(PrimFunc func, const Map& param_map); diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5cd860b8e929..4f5772822d9e 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1078,9 +1078,9 @@ class MatchBufferRegion : public ObjectRef { * \note Block's body is parameterized by iter vars. * \code * - * with T.block([extent0, extent1, ...], name) as [v0, v1, ...]: - * T.bind(v0, value0) - * T.bind(v1, value1) + * with T.block(name): + * v0 = T.axis.S(domain, value0) + * v1 = T.axis.R(domain, value1) * ... * T.reads([buffer0[start:end, ...], ...]) * T.writes([buffer1[start:end, ...], ...]) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 017078bd7bf7..e6b0af9773d9 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -388,7 +388,7 @@ TVM_DLL Pass ConvertBlocksToOpaque(); * \code * * for i in range(0, 16): - * with T.block([]): + * with T.block(): * B = T.alloc_buffer(16, 16) * for j in range(0, 16): * B[i, j] = A[i, j] + 1 @@ -404,7 +404,7 @@ TVM_DLL Pass ConvertBlocksToOpaque(); * \code * * for i in range(0, 16): - * with T.block([]): + * with T.block(): * B = T.alloc_buffer(1, 16) * for j in range(0, 16): * B[0, j] = A[i, j] + 1 diff --git a/python/gen_requirements.py b/python/gen_requirements.py index bdaa58e1449a..e9f3772ee733 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -251,7 +251,7 @@ ("sphinx_autodoc_annotation", None), ("sphinx_gallery", None), ("sphinx_rtd_theme", None), - ("synr", "==0.4.1"), + ("synr", "==0.5.0"), ("tensorflow", None), ("tensorflow-estimator", None), ("tflite", None), diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 75566cf6e2c5..56d080857a7d 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -22,8 +22,10 @@ import tvm from tvm.ir import Span +from tvm.ir.expr import Range from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion from tvm.runtime import Object +from tvm.tir.expr import IterVar from .tir.node import BufferSlice @@ -41,10 +43,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(a, (16, 16), "float32") for i, j, k in T.grid(16, 16, 16): - with T.block([16, 16, T.reduce_axis(16)], "matmul") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k} + with T.block("matmul"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k) # iter_bindings = {vj: i, vj: j, vk: k} T.where(True) # predicate of the block_realize @@ -72,8 +74,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: """List[Buffer]: list of T.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature""" - iter_bindings: Mapping[Var, PrimExpr] = {} - """Mapping[Var, PrimExpr]: map of block iter var to its values""" + iter_values: List[PrimExpr] = [] + """List[PrimExpr]: list of binding values for iter vars""" + iter_vars: List[IterVar] = [] + """List[PrimExpr]: list of iter vars in the block""" reads: Optional[List[BufferSlice]] = None """Optional[List[BufferSlice]]: list of T.reads statements in the block signature, None for not-visited""" @@ -91,7 +95,8 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: def __init__(self): self.alloc_buffers = [] self.match_buffers = [] - self.iter_bindings = {} + self.iter_values = [] + self.iter_vars = [] self.reads = None self.writes = None self.annotations = None @@ -112,8 +117,8 @@ class ContextMaintainer: """List[List[synr.ast.Node]]: The ast nodes insides the current scope""" block_info_stack: List[BlockInfo] = [] """List[BlockInfo]: The block info for the current block scope""" - loop_stack: List[List[Var]] = [] - """List[List[Var]]: List of loop vars inside the current block scope""" + loop_stack: Dict[Var, Range] = {} + """Dict[Var, Range]: The dict from loop var to its domain outside the block""" symbols: List[Dict[str, Union[Var, Buffer]]] = [] """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" @@ -137,7 +142,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No # scope context self.node_stack = [] self.block_info_stack = [] - self.loop_stack = [] + self.loop_stack = {} self.symbols = [] # function context self.func_params = [] @@ -183,8 +188,6 @@ def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None): The synr AST nodes in new scope """ self.enter_scope(nodes) - # Create a new loop stack for the new block - self.loop_stack.append([]) # Create a new BlockInfo for the new block self.block_info_stack.append(BlockInfo()) @@ -196,8 +199,6 @@ def exit_scope(self): def exit_block_scope(self): """Pop the inner most block scope, the function will call `exit_scope` implicitly""" self.exit_scope() - # Pop loop stack - self.loop_stack.pop() # Pop block_info self.block_info_stack.pop() diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index d5e79e8676c1..8610d91e9f07 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -377,12 +377,13 @@ def A(): """ if len(node.assignments) == 1: if not ( - isinstance(node.assignments[0].lhs, ast.Var) - and node.assignments[0].lhs.id.name == "__tvm_meta__" + len(node.assignments[0].lhs) == 1 + and isinstance(node.assignments[0].lhs[0], ast.Var) + and node.assignments[0].lhs[0].id.name == "__tvm_meta__" ): self.report_error( "The only top level assignments allowed are `__tvm_meta__ = ...`", - node.assignments[0].lhs.span, + node.assignments[0].span, ) self.init_meta( MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context) @@ -526,18 +527,19 @@ def transform_Assign(self, node): return self.parse_body(node) else: value = self.transform(node.rhs) - if not isinstance(node.lhs, ast.Var): + if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): # This is a little confusing because it only is true when # we have taken this branch. We might need to clarify what # exectly is allowed in Assignments in tvmscript. self.report_error( "Left hand side of assignment must be an unqualified variable", - node.lhs.span, + node.span, ) + ast_var = node.lhs[0] var = tvm.te.var( - node.lhs.id.name, - self.parse_type(node.ty, node.lhs), - span=tvm_span_from_synr(node.lhs.span), + ast_var.id.name, + self.parse_type(node.ty, ast_var), + span=tvm_span_from_synr(ast_var.span), ) self.context.update_symbol(var.name, var, node) body = self.parse_body(node) @@ -596,7 +598,7 @@ def transform_For(self, node): For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) By now 1 pattern of For is supported: 1. for scope handler - for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/tir.range()/ + for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/ T.grid()/T.thread_binding() """ @@ -892,9 +894,20 @@ def transform_Attr(self, node): namespace. """ - if isinstance(node.object, ast.Var): - if self.match_tir_namespace(node.object.id.name): - func_name = "tir." + node.field.name + def get_full_attr_name(node: ast.Attr) -> str: + reverse_field_names = [node.field.name] + while isinstance(node.object, ast.Attr): + node = node.object + reverse_field_names.append(node.field.name) + if isinstance(node.object, ast.Var): + reverse_field_names.append(node.object.id.name) + return ".".join(reversed(reverse_field_names)) + + if isinstance(node.object, (ast.Var, ast.Attr)): + full_attr_name = get_full_attr_name(node) + attr_object, fields = full_attr_name.split(".", maxsplit=1) + if self.match_tir_namespace(attr_object): + func_name = "tir." + fields res = Registry.lookup(func_name) if res is not None: return res @@ -903,9 +916,7 @@ def transform_Attr(self, node): except TVMError as e: # Check if we got an attribute error if e.args[0].find("AttributeError"): - self.report_error( - f"Unregistered function `tir.{node.field.name}`.", node.field.span - ) + self.report_error(f"Unregistered function `tir.{fields}`.", node.span) else: raise e diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 487a71d4f077..4750ad7626e2 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -134,12 +134,14 @@ def enter_scope( if isinstance(node, synr.ast.With): vars = WithScopeHandler.get_optional_vars(node, context) if len(vars) != 1: - context.report_error("Unexpected number of vars", node.span) + context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) name = vars[0].id.name var_span = vars[0].id.span elif isinstance(node, synr.ast.Assign): - name = node.lhs.id.name - var_span = node.lhs.id.span + if len(node.lhs) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) + name = node.lhs[0].id.name + var_span = node.lhs[0].id.span else: raise Exception("Internal Bug") @@ -247,42 +249,16 @@ def let(var, value, span): @register class Block(WithScopeHandler): - """With scope handler T.block(extents, name) as iter_vars""" + """With scope handler T.block(name)""" def __init__(self): - def block(axes=None, name_hint: str = "", span: Optional[Span] = None): + def block(name_hint: str = "", span: Optional[Span] = None): assert ( self.node and self.context and self.body ), "call 'exit_scope' before 'enter_scope'" block_info = self.context.block_info_stack[-1] - if axes is None: - axes = [] - if len(axes) != len(self.block_vars): - self.context.report_error( - "Inconsistent number of block vars, " - + f"there are {len(axes)} axes but {len(self.block_vars)} block vars. " - + "The number of block vars should match the number of axes.", - self.node.span, - ) - block_iters: List[IterVar] = [] - for i, axis in enumerate(axes): - axis = tvm.runtime.convert(axis) - if isinstance(axis, tvm.tir.PrimExpr): - block_var_dom = Range.from_min_extent(0, axis) - block_iters.append(IterVar(block_var_dom, self.block_vars[i], 0)) - elif isinstance(axis, Range): - block_iters.append(IterVar(axis, self.block_vars[i], 0)) - elif isinstance(axis, IterVar): - block_iters.append(IterVar(axis.dom, self.block_vars[i], axis.iter_type)) - else: - self.context.report_error( - "Invalid argument of T.block(), " - + f"expected PrimExpr, Range or IterVar, but got {type(axis)}", - self.node.span, - ) # create block read/write regions - reads: List[BufferRegion] = ( [buffer_slice_to_region(read) for read in block_info.reads] if block_info.reads @@ -301,7 +277,7 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): if region_detect_mask != 0: annotations["tir.script_parsing_detect_access"] = region_detect_mask inner = tvm.tir.Block( - block_iters, + block_info.iter_vars, reads, writes, name_hint, @@ -312,35 +288,13 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): annotations, span, ) - # create block var iter binding - values: List[PrimExpr] - if not block_info.iter_bindings: - values = self.context.loop_stack[-2].copy() - if len(block_iters) == 0: - # It is an opaque block without any bindings - values = [] - elif len(values) == 0: - values = [tvm.tir.const(float("nan"), dtype="float32")] * len(block_iters) - elif len(values) != len(block_iters): - self.context.report_error( - "Number of block iter var and outer loop nesting mismatch, " - + f"{len(block_iters)} block iter vars but {len(values)} loops", - self.node.span, - ) - else: - for block_var in self.block_vars: - if block_var not in block_info.iter_bindings: - self.context.report_error( - "Missing block iter var binding for " + block_var.name, - self.node.span, - ) - values = [block_info.iter_bindings[block_var] for block_var in self.block_vars] + assert len(block_info.iter_vars) == len(block_info.iter_values) predicate = ( tvm.tir.const(True, "bool") if block_info.predicate is None else block_info.predicate ) - body = tvm.tir.BlockRealize(values, predicate, inner, span) + body = tvm.tir.BlockRealize(block_info.iter_values, predicate, inner, span) return body super().__init__(func=block, concise_scope=False, def_symbol=True) @@ -358,10 +312,13 @@ def enter_scope( node, synr.ast.With ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}" - vars = WithScopeHandler.get_optional_vars(node, context) - self.block_vars = [tvm.te.var(var.id.name) for var in vars] - for block_var in self.block_vars: - context.update_symbol(block_var.name, block_var, node) + optional_vars = [var.id.name for var in WithScopeHandler.get_optional_vars(node, context)] + if optional_vars: + context.report_error( + f"Block expected no optional_vars (e.g., `x` in `with block() as x`), " + f"but got {optional_vars}", + node.span, + ) @register @@ -378,12 +335,38 @@ def init(span: Span = None): super().__init__(func=init, concise_scope=False, def_symbol=True) +class LoopInfo: + """Helper class for loop information""" + + loop_var: Var + begin: PrimExpr + extent: PrimExpr + kind: ForKind + thread_binding: Optional[str] + annotations: Optional[Mapping[str, Object]] + + def __init__( + self, + begin: PrimExpr, + extent: PrimExpr, + kind: ForKind, + thread_binding: Optional[str] = None, + annotations: Optional[Mapping[str, Object]] = None, + ) -> None: + self.begin = begin + self.extent = extent + self.kind = kind + self.thread_binding = thread_binding + self.annotations = annotations + + class ForScopeHandler(ScopeHandler): """Base class for all for scope handlers""" def __init__(self, func): super().__init__(func) - self.loop_vars: Optional[List[Var]] = None + self.loop_vars: List[Var] = [] + self.loop_info: List[LoopInfo] = [] def enter_scope( self, @@ -415,12 +398,23 @@ def enter_scope( span, ) + self.node = node + self.context = context + # generate loop vars self.loop_vars = [ tvm.te.var(name, dtype="int32", span=span) for name, span in zip(loop_var_names, spans) ] - for loop_var in self.loop_vars: + # collect loop infos by calling self.func + call_with_error_reporting(context.report_error, span, self.func, *arg_list) + if len(self.loop_vars) != len(self.loop_info): + self.context.report_error( + f"Inconsistent number of vars and loops, got {len(self.loop_vars)} " + + f"vs {len(self.loop_info)}", + self.node.span, + ) + for loop_var, loop_info in zip(self.loop_vars, self.loop_info): context.update_symbol(loop_var.name, loop_var, node) - context.loop_stack[-1].append(loop_var) + context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent) def exit_scope( self, @@ -430,19 +424,34 @@ def exit_scope( span: synr.ast.Span, ): assert self.loop_vars, "call 'exit_scope' before 'enter_scope'" - for _ in self.loop_vars: - context.loop_stack[-1].pop() - return super().exit_scope(node, context, arg_list, span) + for loop_var in self.loop_vars: + context.loop_stack.pop(loop_var) + # Use assert here since we have check it in `enter_scope` + assert len(self.loop_vars) == len(self.loop_info) + + body = self.body + for var, info in zip(reversed(self.loop_vars), reversed(self.loop_info)): + body = tvm.tir.For( + var, + info.begin, + info.extent, + info.kind, + body, + info.thread_binding, + info.annotations, + span=tvm_span_from_synr(span), + ) - def create_loop( + return body + + def create_loop_info( self, begin: PrimExpr, end: PrimExpr, kind: ForKind, thread_binding: Optional[str] = None, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, - ) -> tvm.tir.For: + ) -> None: """ Helper function for creating For in TVM Script parser. @@ -471,30 +480,16 @@ def create_loop( for : For The constructed For. """ - assert ( - self.loop_vars and self.context and self.node - ), "call 'exit_scope' before 'enter_scope'" - if len(self.loop_vars) != 1: - self.context.report_error( - f"Expected exactly one loop var, but got {self.loop_vars}", self.node.span - ) + assert self.context and self.node, "call 'exit_scope' before 'enter_scope'" extent = end if begin == 0 else self.context.analyzer.simplify(end - begin) - annos: Mapping[str, Object] = {} + self.annotations: Mapping[str, Object] = {} if annotations is not None: - annos = { + self.annotations = { key: tvm.tir.StringImm(val) if isinstance(val, str) else val for key, val in annotations.items() } - return tvm.tir.For( - self.loop_vars[0], - begin, - extent, - kind, - self.body, - thread_binding=thread_binding, - annotations=annos, - span=span, - ) + + self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations)) @register @@ -506,9 +501,8 @@ def serial( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop(begin, end, ForKind.SERIAL, annotations=annotations, span=span) + self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) super().__init__(serial) @@ -522,11 +516,8 @@ def parallel( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.PARALLEL, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations) super().__init__(parallel) @@ -540,11 +531,8 @@ def vectorized( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.VECTORIZED, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations) super().__init__(vectorized) @@ -558,11 +546,8 @@ def unroll( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.UNROLLED, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations) super().__init__(unroll) @@ -577,16 +562,14 @@ def thread_binding( end: PrimExpr, thread: str, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread, span=span) - return self.create_loop( + thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread) + self.create_loop_info( begin, end, ForKind.THREAD_BINDING, thread_binding=thread_iter_var, annotations=annotations, - span=span, ) super().__init__(thread_binding) @@ -603,12 +586,11 @@ def for_range( begin: PrimExpr, end: PrimExpr = None, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): if end is None: end = begin begin = 0 - return self.create_loop(begin, end, ForKind.SERIAL, annotations=annotations, span=span) + self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) super().__init__(for_range) @@ -621,19 +603,8 @@ class Grid(ForScopeHandler): """For scope handler T.grid(extents)""" def __init__(self): - def grid(*extents: List[PrimExpr], span: Span): - assert ( - self.node and self.context and self.loop_vars - ), "call 'exit_scope' before 'enter_scope'" - if len(self.loop_vars) != len(extents): - self.context.report_error( - "Inconsistent number of loop vars and extents, " - + f"got {len(self.loop_vars)} vs {len(extents)}", - self.node.span, - ) - body = self.body - for loop_var, extent in zip(reversed(self.loop_vars), reversed(extents)): - body = tvm.tir.For(loop_var, 0, extent, ForKind.SERIAL, body, span=span) - return body + def grid(*extents: List[PrimExpr]): + for extent in extents: + self.create_loop_info(0, extent, ForKind.SERIAL) super().__init__(grid) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 69cf15f493de..de212352f3e4 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -21,17 +21,18 @@ import synr from synr import ast +from tvm.ir.expr import PrimExpr, Range import tvm.tir from tvm.runtime import Object from tvm import te from tvm.ir import Span -from tvm.tir import IntImm +from tvm.tir import IntImm, IterVar from .node import BufferSlice from .utils import buffer_slice_to_region -from ..context_maintainer import ContextMaintainer +from ..context_maintainer import BlockInfo, ContextMaintainer from ..registry import register from ..utils import ( get_param_list, @@ -132,9 +133,10 @@ def match_buffer( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "match_buffer must be assigned to a buffer, e.g. A = match_buffer(...)", + "`match_buffer` must be assigned to a single buffer, " + "e.g. A = match_buffer(...)", self.node.span, ) if strides is None: @@ -143,10 +145,11 @@ def match_buffer( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -173,7 +176,7 @@ def match_buffer( + str(type(param)), self.node.rhs.params[0].span, ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) super().__init__(match_buffer, def_symbol=True) @@ -201,9 +204,9 @@ def buffer_decl( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "buffer_decl must be assigned to a buffer, e.g. A = buffer_decl(...)", + "`buffer_decl` must be assigned to a single buffer, e.g. A = buffer_decl(...)", self.node.span, ) @@ -213,10 +216,11 @@ def buffer_decl( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -226,7 +230,7 @@ def buffer_decl( buffer_type, span=span, ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) return buffer super().__init__(buffer_decl, def_symbol=True) @@ -257,9 +261,10 @@ def alloc_buffer( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "alloc_buffer must be assigned to a buffer, e.g. A = alloc_buffer(...)", + "`alloc_buffer` must be assigned to a single buffer, " + "e.g. A = alloc_buffer(...)", self.node.span, ) @@ -269,10 +274,11 @@ def alloc_buffer( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -283,32 +289,11 @@ def alloc_buffer( span=span, ) self.context.current_block_scope().alloc_buffers.append(buffer) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) super().__init__(alloc_buffer, def_symbol=True) -@register -class BlockVarBind(SpecialStmt): - """Special function bind(block_iter, binding_value) - - Example - ------- - .. code-block:: python - - T.bind(vx, i) - """ - - def __init__(self): - def bind(iter_var, values, span=None): - block_scope = self.context.current_block_scope() - if iter_var in block_scope.iter_bindings: - self.context.report_error("Duplicate iter_var bindings of " + str(iter_var), span) - block_scope.iter_bindings[iter_var] = values - - super().__init__(bind, def_symbol=False) - - @register class BlockReads(SpecialStmt): """Special function reads([read_buffer_regions]) @@ -412,6 +397,315 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None): super().__init__(block_attr, def_symbol=False) +class BlockAxis(SpecialStmt): + """Special stmt for defining a spatial block axis + axis.S(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.S(128, i * 4 + j) + """ + + def axis( + self, + var_name: str, + dom: Union[PrimExpr, Range], + value: PrimExpr, + iter_type: int, + span: Optional[Span] = None, + ) -> None: + """ + Helper function for creating block axis + + Parameters + ---------- + var_name : str + The name_hint of var + + dom : Union[PrimExpr, Range] + The iter domain. + + value : PrimExpr + The binding value + + iter_type : int + The iteration type. + + span : Optional[Span] + The location of this for in the source code. + """ + assert self.context, "call 'exit_scope' before 'enter_scope'" + block_scope: BlockInfo = self.context.current_block_scope() + if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]: + self.context.report_error("Duplicate block axis " + var_name, self.node.span) + + block_var = tvm.tir.Var(var_name, dtype="int32") + dom = tvm.runtime.convert(dom) + if isinstance(dom, PrimExpr): + dom = tvm.ir.Range.from_min_extent(0, dom) + elif not isinstance(dom, tvm.ir.Range): + self.context.report_error( + f"Block axis domain expected PrimExpr or Range, but got {type(value)}", + self.node.span, + ) + value = tvm.runtime.convert(value) + if not isinstance(value, PrimExpr): + self.context.report_error( + f"Block axis value expected PrimExpr, but got {type(value)}", + self.node.span, + ) + iter_var = tvm.tir.IterVar(dom, block_var, iter_type) + block_scope.iter_vars.append(iter_var) + block_scope.iter_values.append(value) + self.context.update_symbol(var_name, block_var, self.node) + + +@register +class BlockAxisSpatial(BlockAxis): + """Special stmt for defining a spatial block axis + axis.spatial(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.spatial(128, k) + """ + + def __init__(self): + def axis_spatial( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.spatial` must be assigned to a var, e.g. vi = axis.spatial(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar) + + super().__init__(axis_spatial, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.spatial", get_param_list(self.func) + + +@register +class BlockAxisS(BlockAxis): + """The sugar special stmt for defining a spatial block axis + axis.S(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.S(128, k) + """ + + def __init__(self): + def axis_spatial( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.S` must be assigned to a var, e.g. vi = axis.S(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar) + + super().__init__(axis_spatial, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.S", get_param_list(self.func) + + +@register +class BlockAxisReduce(BlockAxis): + """Special stmt for defining a reduce block axis + axis.reduce(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.reduce(128, k) + """ + + def __init__(self): + def axis_reduce( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.reduce` must be assigned` to a var, e.g. vi = axis.reduce(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce) + + super().__init__(axis_reduce, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.reduce", get_param_list(self.func) + + +@register +class BlockAxisR(BlockAxis): + """The sugar special stmt for defining a reduce block axis + axis.R(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.R(128, k) + """ + + def __init__(self): + def axis_reduce( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.R` must be assigned to a var, e.g. vi = axis.R(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce) + + super().__init__(axis_reduce, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.R", get_param_list(self.func) + + +@register +class BlockAxisScan(BlockAxis): + """Special stmt for defining a ordered block axis + axis.scan(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.scan(128, k) + """ + + def __init__(self): + def axis_scan( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.scan` must be assigned to a var, e.g. vi = axis.scan(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.Ordered) + + super().__init__(axis_scan, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.scan", get_param_list(self.func) + + +@register +class BlockAxisOpaque(BlockAxis): + """Special stmt for defining a opaque block axis + axis.opaque(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.opaque(128, k) + """ + + def __init__(self): + def axis_opaque( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.opaque` must be assigned to a var, e.g. vi = axis.opaque(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DimInfo) + + super().__init__(axis_opaque, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.opaque", get_param_list(self.func) + + +@register +class BlockAxisRemap(BlockAxis): + """Special stmt for remapping loops vars to block axes. + axis.remap(iter_type, iter_value) + + Note + ---- + Iter_type is a string consisting of 'S' and 'R', where 'S' means + for spatial and 'R' means for reduce. + + Example + ------- + .. code-block:: python + + vi, vj = T.axis.remap("SS", [i, j]) + """ + + def __init__(self): + def axis_remap(iter_types: str, loop_vars: List[tvm.tir.expr.Var], span: Span = None): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) >= 1: + self.context.report_error( + "`axis.remap` must be assigned to one or more vars, " + "e.g. vi, vj = axis.remap(...)", + self.node.span, + ) + var_num: int = len(self.node.lhs) + if var_num != len(iter_types): + self.context.report_error( + f"`iter_type` expected {var_num} charactor(s), " + f"but got {len(iter_types)}: {iter_types}", + span, + ) + if var_num != len(loop_vars): + self.context.report_error( + f"`iter_type` expected {var_num} loop var(s), " + f"but got {len(loop_vars)}: {loop_vars}", + span, + ) + for var, iter_ty, loop_var in zip(self.node.lhs, iter_types, loop_vars): + iter_type: int + if iter_ty == "S": + iter_type = IterVar.DataPar + elif iter_ty == "R": + iter_type = IterVar.CommReduce + else: + self.context.report_error( + f'`iter_type` only expected "S" (for spatial) or "R" (for reduce), ' + f'but got "{iter_ty}"', + span, + ) + + if not isinstance(loop_var, tvm.tir.expr.Var): + self.context.report_error( + f"Values of `axis.remap` expected single loop var, but got {loop_var}", + loop_var.span, + ) + loops = self.context.loop_stack + if loop_var not in loops: + self.context.report_error( + f"Cannot find loop var {loop_var} in loop nesting.", + span, + ) + self.axis(var.id.name, loops[loop_var], loop_var, iter_type) + + super().__init__(axis_remap, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.remap", get_param_list(self.func) + + @register class BlockPredicate(SpecialStmt): """Special function where(predicate) @@ -449,7 +743,12 @@ def var(dtype, span): assert isinstance( self.node, ast.Assign ), f"VarDef expected ast.Assign but got {type(self.node)}" - v = te.var(self.node.lhs.id.name, dtype, span=span) + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) + v = te.var(names[0], dtype, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(var, def_symbol=True) @@ -464,8 +763,13 @@ def buffer_var(dtype, storage_scope, span): assert isinstance( self.node, ast.Assign ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) - v = te.var(self.node.lhs.id.name, ptr_type, span=span) + v = te.var(names[0], ptr_type, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(buffer_var, def_symbol=True) @@ -480,7 +784,12 @@ def env_thread(env_name, span): assert isinstance( self.node, ast.Assign ), f"EnvThread expected ast.Assign but got {type(self.node)}" - v = te.var(self.node.lhs.id.name, span=span) + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) + v = te.var(names[0], span=span) self.context.func_var_env_dict[v] = env_name self.context.update_symbol(v.name, v, self.node) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 681e322b2082..cb0305d49e4a 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -467,10 +467,12 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i, j, k in T.grip(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] Returns ------- diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 6a90924912b1..b002ace0e400 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -108,8 +108,10 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: A = T.match_buffer(a, (m, n), "float32") B = T.match_buffer(b, (m, n), "float32") - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] Then we can make it specialized with given shapes or buffers. @@ -129,8 +131,10 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") - with T.block([16, 16], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] Returns ------- diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 09a52d2e7037..786982cf704c 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -397,7 +397,8 @@ def before_fuse(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do fuse: @@ -419,9 +420,9 @@ def after_fuse(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # the 2 loops are fused into 1 for i_j_fused in T.serial(0, 16384): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, tir.floordiv(i_j_fused, 128)) - T.bind(vj, T.floormod(i_j_fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_fused, 128)) + vj = T.axis.S(128, T.floormod(i_j_fused, 128)) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -468,7 +469,8 @@ def before_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B") as [vi, vj]: + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do split: @@ -490,9 +492,9 @@ def after_split(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # the original loop is split into 2 loops for i0, i1, j in T.grid(2, 64, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, ((i0*64) + i1)) - T.bind(vj, j) + with T.block("B"): + vi = T.axis.S(128, i0 * 64 + i1) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -529,7 +531,8 @@ def before_reorder(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do reorder: @@ -551,9 +554,8 @@ def after_reorder(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # Here j and i are reordered for j, i in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -586,9 +588,8 @@ def before_parallel(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do parallel: @@ -609,9 +610,8 @@ def after_parallel(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.parallel(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -642,9 +642,8 @@ def before_vectorize(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do vectorize: @@ -665,9 +664,8 @@ def after_vectorize(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.serial(0, 128): for j in T.vectorized(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -706,9 +704,8 @@ def before_bind(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do bind: @@ -730,9 +727,8 @@ def after_bind(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.thread_binding(0, 128, thread = "blockIdx.x"): for j in T.thread_binding(0, 128, thread = "threadIdx.x"): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -758,9 +754,8 @@ def before_unroll(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do unroll: @@ -781,9 +776,8 @@ def after_unroll(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.unroll(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -825,7 +819,8 @@ def before_cache_read(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and cache_read: @@ -847,10 +842,12 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) A_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block([128, 128], "A_local") as [vi, vj]: + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) A_local[vi, vj] = A[vi, vj] for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_local[vi, vj] * 2.0 """ @@ -893,7 +890,8 @@ def before_cache_write(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and cache_write: @@ -915,10 +913,12 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) B_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block([128, 128], "A_local") as [vi, vj]: + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) B_local[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_local[vi, vj] """ @@ -974,10 +974,14 @@ def before_compute_at(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do compute-at: @@ -1000,14 +1004,12 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 """ @@ -1061,10 +1063,14 @@ def before_reverse_compute_at(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do reverse-compute-at: @@ -1087,14 +1093,12 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 """ @@ -1135,10 +1139,14 @@ def before_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do compute-inline: @@ -1156,8 +1164,10 @@ def before_inline(a: T.handle, c: T.handle) -> None: def after_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member @@ -1195,10 +1205,14 @@ def before_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do reverse-compute-inline: @@ -1216,8 +1230,10 @@ def before_inline(a: T.handle, c: T.handle) -> None: def after_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member @@ -1384,8 +1400,9 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: def before_rfactor(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128), - T.reduce_axis(0, 128)], "B") as [vii, vi, vj]: + for ii, i, j in T.grid(128, 128, 128): + with T.block("B"): + vii, vi, vj = T.axis.remap("SRR", [ii, i, j]) with T.init(): B[vii] = 0.0 B[vii] = B[vii] + A[vii, vi, vj] @@ -1408,14 +1425,18 @@ def after_rfactor(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128]) B_rf = T.alloc_buffer([128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]: - with T.init(): - B_rf[vi2, vii] = 0.0 - B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]: - with T.init(): - B[vii_1] = 0.0 - B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1]) + for i2, ii, i in T.grid(128, 128, 128): + with T.block("B_rf"): + vi2, vii, vi = T.axis.remap("SSR", [i2, ii, i]) + with T.init(): + B_rf[vi2, vii] = 0.0 + B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) + for ii, i2 in T.grid(128, 128): + with T.block("B"): + vii, vi2 = T.axis.remap("SR", [ii, i2]) + with T.init(): + B[vii] = 0.0 + B[vii] = B[vii] + B_rf[vi2, vii] Note @@ -1483,10 +1504,14 @@ def before_storage_align(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do storage_align: @@ -1505,11 +1530,15 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 After lowering passes, buffer B will have strides as [129, 1]. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 1abba77a801f..722810e9aa5b 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -628,7 +628,7 @@ def CompactBufferAllocation(): .. code-block:: python for i in range(0, 16): - with T.block([]): + with T.block(): B = T.alloc_buffer(16, 16) for j in range(0, 16): B[i, j] = A[i, j] + 1 @@ -643,7 +643,7 @@ def CompactBufferAllocation(): .. code-block:: python for i in range(0, 16): - with T.block([]): + with T.block(): B = T.alloc_buffer(1, 16) for j in range(0, 16): B[0, j] = A[i, j] + 1 diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index fa74e56f491c..13e4cfcd30ba 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -22,10 +22,10 @@ * \brief Printer class to print Tensor IR to python syntax script */ -#include #include #include #include +#include #include #include #include @@ -128,7 +128,17 @@ class TVMScriptPrinter : public StmtFunctor, /*! \brief the number of current node */ int current_num_; /*! \brief loop stack without annotations */ - std::vector loop_stack_; + std::vector simple_loop_stack_; + /*! \brief the maps from loop_vars to the loops */ + std::unordered_map loop_var_map_; + /*! + * \brief simple block vars remap from loop vars + * simple_remap requires: + * 1. block var iter type is kDataPar or kCommReduce + * 2. value is a single Var, which is a loop_var outside the block + * 3. The iter range is equal to loop range + */ + std::vector> block_var_remaps_; Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override; @@ -193,7 +203,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); Doc AllocBufferDeclaration(const Buffer& buf); - Doc PrintBlockVar(const BlockNode* op); + Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); + Doc PrintBlockVarRemaps(); + Doc PrintBlockVars(const BlockRealizeNode* op); Doc PrintBlockAttr(const BlockRealizeNode* op); Doc PrintBlockBody(const BlockNode* op); Doc PrintBufferRegion(const BufferRegionNode* op); @@ -821,21 +833,23 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { Doc doc; var_not_in_headers_.insert(op->loop_var.get()); + loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); - if (simple_loop) loop_stack_.push_back(GetRef(op)); + if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out if (simple_loop && body != nullptr) { Doc result = Print(GetRef(body)); TryDeallocVar(op->loop_var); + loop_var_map_.erase(op->loop_var.get()); return result; } // It is a loop that can not be compressed - bool print_above = !loop_stack_.empty(); + bool print_above = !simple_loop_stack_.empty(); // print loops above if needed if (print_above) { doc << PrintLoopStack(); - loop_stack_.clear(); + simple_loop_stack_.clear(); } if (!simple_loop) { // print current loop if needed @@ -847,6 +861,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } TryDeallocVar(op->loop_var); + loop_var_map_.erase(op->loop_var.get()); return doc; } @@ -901,52 +916,99 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } -Doc TVMScriptPrinter::PrintBlockVar(const BlockNode* op) { +Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { Doc doc; - doc << "with " << tir_prefix_ << ".block(["; - std::vector block_var_docs; - for (const auto& iter_var : op->iter_vars) { - Doc block_var_doc; - if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) { - block_var_doc << Print(iter_var->dom->extent); + doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis."; + switch (iter_var->iter_type) { + case kDataPar: + doc << "spatial"; + break; + case kCommReduce: + doc << "reduce"; + break; + case kOrdered: + doc << "scan"; + break; + case kOpaque: + doc << "opaque"; + break; + default: + LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type; + break; + } + doc << "("; + const Range& dom = iter_var->dom; + if (is_zero(dom->min)) { + doc << Print(dom->extent); + } else { + doc << "(" << Print(dom->min) << ", " << Print(dom->min + dom->extent) << ")"; + } + doc << ", " << Print(value) << ")"; + return doc; +} + +Doc TVMScriptPrinter::PrintBlockVarRemaps() { + ICHECK(!block_var_remaps_.empty()); + if (block_var_remaps_.size() == 1) { + const IterVar& iter_var = block_var_remaps_[0].first; + const PrimExpr& value = block_var_remaps_[0].second; + return PrintBlockVar(iter_var, value); + } + Doc doc; + std::vector iter_vars, iter_values; + std::string iter_type; + for (const auto& pair : block_var_remaps_) { + const IterVar& iter_var = pair.first; + const PrimExpr& value = pair.second; + iter_vars.push_back(Print(iter_var->var)); + iter_values.push_back(Print(value)); + if (iter_var->iter_type == kDataPar) { + iter_type += "S"; + } else if (iter_var->iter_type == kCommReduce) { + iter_type += "R"; } else { - block_var_doc << tir_prefix_ << "."; - switch (iter_var->iter_type) { - case kDataPar: - block_var_doc << "range"; - break; - case kCommReduce: - block_var_doc << "reduce_axis"; - break; - case kOrdered: - block_var_doc << "scan_axis"; - break; - case kOpaque: - block_var_doc << "opaque_axis"; - break; - default: - LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type; - break; - } - block_var_doc << "(" << Print(iter_var->dom->min) << ", " - << Print(iter_var->dom->min + iter_var->dom->extent) << ")"; + ICHECK(false); } - block_var_docs.push_back(block_var_doc); - } - doc << PrintSep(block_var_docs, Doc::Text(", ")) << "]"; - if (!op->name_hint.empty()) { - doc << ", " << Doc::StrLiteral(op->name_hint); } - doc << ")"; - std::vector block_var_names; - for (const auto& iter_var : op->iter_vars) { + doc << PrintSep(iter_vars, Doc::Text(", ")) << " = " << tir_prefix_ << ".axis.remap(" + << Doc::StrLiteral(iter_type) << ", [" << PrintSep(iter_values, Doc::Text(", ")) << "])"; + return doc; +} + +Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) { + Doc doc; + const auto* block_op = op->block.as(); + ICHECK_EQ(block_op->iter_vars.size(), op->iter_values.size()); + tir::ExprDeepEqual expr_equal; + + auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var, + const PrimExpr& value) -> bool { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false; + if (!value->IsInstance()) return false; + const Var& var = Downcast(value); + auto it = loop_var_map_.find(var.get()); + return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) && + expr_equal(it->second->extent, iter_var->dom->extent); + }; + + for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { + const IterVar& iter_var = block_op->iter_vars[i]; + const PrimExpr& value = op->iter_values[i]; var_not_in_headers_.insert(iter_var->var.get()); - block_var_names.push_back(Print(iter_var->var)); + if (is_simple_remap(iter_var, value)) { + block_var_remaps_.push_back(std::make_pair(iter_var, value)); + } else { + if (!block_var_remaps_.empty()) { + doc << Doc::NewLine() << PrintBlockVarRemaps(); + block_var_remaps_.clear(); + } + doc << Doc::NewLine() << PrintBlockVar(iter_var, value); + } } - if (!block_var_names.empty()) { - doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]"; + if (!block_var_remaps_.empty()) { + doc << Doc::NewLine() << PrintBlockVarRemaps(); + block_var_remaps_.clear(); } - doc << ":"; return doc; } @@ -957,10 +1019,6 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) { if (!is_one(op->predicate)) { block_attr_doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")"; } - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) - block_attr_doc << Doc::NewLine() << tir_prefix_ << ".bind(" - << Print(block_op->iter_vars[i]->var) << ", " << Print(op->iter_values[i]) - << ")"; block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" << Print(block_op->reads) << ")"; block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" << Print(block_op->writes) << ")"; if (!block_op->annotations.empty()) { @@ -994,12 +1052,18 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { const auto* block_op = op->block.as(); // print block name and block vars - Doc doc = PrintBlockVar(block_op); + Doc doc; + doc << "with " << tir_prefix_ << ".block("; + if (!block_op->name_hint.empty()) { + doc << Doc::StrLiteral(block_op->name_hint); + } + doc << "):"; + Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); // print body Doc body = PrintBlockBody(block_op); - doc << Doc::Indent(4, block_attr_doc << Doc::NewLine() << body); + doc << Doc::Indent(4, block_var << block_attr_doc << Doc::NewLine() << body); for (const auto& iter_var : block_op->iter_vars) { TryDeallocVar(iter_var->var); } @@ -1265,11 +1329,11 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) { Doc TVMScriptPrinter::PrintLoopStack() { Doc res; - if (loop_stack_.size() == 1) { - res << PrintLoop(loop_stack_[0]); - } else if (loop_stack_.size() > 1) { + if (simple_loop_stack_.size() == 1) { + res << PrintLoop(simple_loop_stack_[0]); + } else if (simple_loop_stack_.size() > 1) { std::vector vars, extents; - for (const auto& loop : loop_stack_) { + for (const auto& loop : simple_loop_stack_) { vars.push_back(Print(loop->loop_var)); extents.push_back(Print(loop->extent)); } diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index 3c4604d18a0d..7e3d3d107507 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -44,21 +44,11 @@ class ScriptCompleter : public StmtMutator { Map* buffer_var_map_; Stmt VisitStmt_(const BlockRealizeNode* op) override { contains_block = true; - Stmt body = StmtMutator::VisitStmt_(op); - if (!op->iter_values.empty() && !op->iter_values[0].dtype().is_int()) { - auto block_with_binding = CopyOnWrite(Downcast(body).get()); - std::vector bindings; - for (size_t i = 0; i < op->iter_values.size(); ++i) { - bindings.push_back(Var("i" + std::to_string(i))); - } - block_with_binding->iter_values = bindings; - body = BlockRealize(block_with_binding); - for (int i = op->iter_values.size() - 1; i >= 0; --i) { - body = For(Downcast(bindings[i]), op->block->iter_vars[i]->dom->min, - op->block->iter_vars[i]->dom->extent, {}, body); - } + for (const PrimExpr& value : op->iter_values) { + CHECK(value.dtype().is_int()) + << "BlockRealize iter_value expected a IntImm, but got " << value.dtype(); } - return body; + return StmtMutator::VisitStmt_(op); } Stmt VisitStmt_(const BlockNode* op) override { diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py index 690258c2fa3b..63733b05ab3f 100644 --- a/tests/python/integration/test_lower.py +++ b/tests/python/integration/test_lower.py @@ -33,9 +33,8 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: # body for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"): for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"): - with T.block([16, 8]) as [bx, by]: - T.bind(bx, blockIdx_x) - T.bind(by, blockIdx_y) + with T.block(): + bx, by = T.axis.remap("SS", [blockIdx_x, blockIdx_y]) shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared") shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared") wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a") @@ -44,9 +43,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: for ty in T.thread_binding(0, 2, "threadIdx.y"): for tz in T.thread_binding(0, 2, "threadIdx.z"): for i, j in T.grid(2, 4): - with T.block([64, 64]) as [vi, vj]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) T.reads([]) T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) C0 = T.match_buffer( @@ -74,23 +73,23 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: for tx in T.thread_binding(0, 32, "threadIdx.x"): for i0, j0 in T.grid(1, 4): for j1 in T.vectorized(0, 4): - with T.block([1024, 1024]) as [vi, vj]: - T.bind(vi, bx * 64 + ty * 32 + tx + i0) - T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + with T.block(): + vi = T.axis.S(1024, bx * 64 + ty * 32 + tx + i0) + vj = T.axis.S(1024, ko * 32 + tz * 16 + j0 * 4 + j1) shared_A[vi, vj + 8] = A[vi, vj] for i0, j0 in T.grid(2, 4): for j1 in T.vectorized(0, 4): - with T.block([1024, 1024]) as [vi, vj]: - T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0) - T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + with T.block(): + vi = T.axis.S(1024, by * 128 + ty * 64 + tx * 2 + i0) + vj = T.axis.S(1024, ko * 32 + tz * 16 + j0 * 4 + j1) shared_B[vi, vj + 8] = B[vi, vj] for ki in range(0, 2): for i in range(0, 2): - with T.block([64, 64]) as [vi, vk]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vk, ko * 2 + ki) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vk = T.axis.S(64, ko * 2 + ki) T.reads( shared_A[ vi * 16 : vi * 16 + 16, @@ -142,9 +141,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for j in range(0, 4): - with T.block([64, 64]) as [vj, vk]: - T.bind(vj, by * 8 + tz * 4 + j) - T.bind(vk, ko * 2 + ki) + with T.block(): + vj = T.axis.S(64, by * 8 + tz * 4 + j) + vk = T.axis.S(64, ko * 2 + ki) T.reads( shared_B[ vj * 16 : vj * 16 + 16, @@ -196,14 +195,10 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for i, j in T.grid(2, 4): - with T.block([64, 64, T.reduce_axis(0, 64)]) as [ - vi, - vj, - vk, - ]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) - T.bind(vk, ko * 2 + ki) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) + vk = T.axis.R(64, ko * 2 + ki) T.reads( [ wmma_A[ @@ -258,9 +253,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for i, j in T.grid(2, 4): - with T.block([64, 64]) as [vi, vj]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) s0 = T.var("int32") diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 6502f0c67de6..fabf41705698 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -41,10 +41,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) - for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + for k in range(128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_meta_schedule_arg_info.py b/tests/python/unittest/test_meta_schedule_arg_info.py index 7bedea9082d1..62dcb52f7415 100644 --- a/tests/python/unittest/test_meta_schedule_arg_info.py +++ b/tests/python/unittest/test_meta_schedule_arg_info.py @@ -28,10 +28,12 @@ def Matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 256), "float32") B = T.match_buffer(b, (256, 512), "float32") C = T.match_buffer(c, (128, 512), "float32") - with T.block([128, 256, T.reduce_axis(0, 512)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 256, 512): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_builder.py b/tests/python/unittest/test_meta_schedule_builder.py index fa09a092c8c4..fb3fa135a9b8 100644 --- a/tests/python/unittest/test_meta_schedule_builder.py +++ b/tests/python/unittest/test_meta_schedule_builder.py @@ -47,10 +47,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @script.ir_module @@ -64,12 +66,16 @@ def matmul_relu( # pylint: disable=no-self-argument B = T.match_buffer(b, (1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") C = T.alloc_buffer((1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([1024, 1024], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @script.ir_module @@ -82,10 +88,12 @@ def batch_matmul( # pylint: disable=no-self-argument A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) - with T.block([16, 128, 128, T.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index cb39c91eaca4..121ec2fd480b 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -41,10 +41,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -56,12 +58,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (16, 16), "float32") D = T.match_buffer(d, (16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([16, 16], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(16, 16): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) # fmt: on diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py index 9fb1e5ef19c1..46be12569c78 100644 --- a/tests/python/unittest/test_meta_schedule_runner.py +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -68,10 +68,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -83,12 +85,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (16, 16), "float32") D = T.match_buffer(d, (16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([16, 16], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(16, 16): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @tvm.script.ir_module @@ -99,10 +105,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [16, 32, 32]) B = T.match_buffer(b, [16, 32, 32]) C = T.match_buffer(c, [16, 32, 32]) - with T.block([16, 32, 32, T.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 32, 32, 32): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] @tvm.script.ir_module @@ -113,8 +121,10 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [32], "float32") B = T.match_buffer(b, [32], "float32") C = T.match_buffer(c, [32], "float32") - with T.block([32], "add") as [vi]: - C[vi] = A[vi] + B[vi] + for i in range(32): + with T.block("add"): + vi = T.axis.S(32, i) + C[vi] = A[vi] + B[vi] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index e12871391558..9b3ddfd7c789 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -45,10 +45,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (32, 32), "float32") B = T.match_buffer(b, (32, 32), "float32") C = T.match_buffer(c, (32, 32), "float32") - with T.block([32, 32, T.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(32, 32, 32): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 39bb1acf065f..3f7749ca9e2c 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -40,10 +40,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index a30409696543..4854aeb5f5aa 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -48,10 +48,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -63,12 +65,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") C = T.alloc_buffer((1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([1024, 1024], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @tvm.script.ir_module @@ -79,10 +85,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) - with T.block([16, 128, 128, T.reduce_axis(0, 128)], "matmul") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("matmul"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks diff --git a/tests/python/unittest/test_meta_schedule_tune_context.py b/tests/python/unittest/test_meta_schedule_tune_context.py index 44bb949b925b..01a4379e5127 100644 --- a/tests/python/unittest/test_meta_schedule_tune_context.py +++ b/tests/python/unittest/test_meta_schedule_tune_context.py @@ -35,10 +35,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 987898001a1b..6b5c26d08b7b 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -54,10 +54,12 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i0, j0, k0 in T.grid(128, 128, 128): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] def test_matmul(): @@ -77,10 +79,14 @@ def tir_element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) B = T.alloc_buffer((128, 128)) - with T.block([128, 128]) as [i, j]: - B[i, j] = A[i, j] * 2.0 - with T.block([128, 128]) as [i, j]: - C[i, j] = B[i, j] + 1.0 + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + B[i, j] = A[i, j] * 2.0 + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + C[i, j] = B[i, j] + 1.0 def test_element_wise(): @@ -125,19 +131,21 @@ def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [16, 32, 14, 14]) Apad = T.alloc_buffer([16, 16, 16, 16]) - with T.block([16, 16, 16, 16], "Apad") as [nn, cc, yy, xx]: - Apad[nn, cc, yy, xx] = T.if_then_else( - yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14, - A[nn, cc, yy - 1, xx - 1], - 0.0, - dtype="float32", - ) - with T.block( - [16, 32, 14, 14, T.reduce_axis(0, 16), T.reduce_axis(0, 3), T.reduce_axis(0, 3)], "B" - ) as [nn, ff, yy, xx, rc, ry, rx]: - with T.init(): - B[nn, ff, yy, xx] = 0.0 - B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] + for n, c, y, x in T.grid(16, 16, 16, 16): + with T.block("Apad"): + nn, cc, yy, xx = T.axis.remap("SSSS", [n, c, y, x]) + Apad[nn, cc, yy, xx] = T.if_then_else( + yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14, + A[nn, cc, yy - 1, xx - 1], + 0.0, + dtype="float32", + ) + for n, f, y, x, kc, ky, kx in T.grid(16, 32, 14, 14, 16, 3, 3): + with T.block("B"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [n, f, y, x, kc, ky, kx]) + with T.init(): + B[nn, ff, yy, xx] = 0.0 + B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] def test_conv2d(): @@ -163,9 +171,11 @@ def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> B1 = T.match_buffer(b1, (m, n)) for i0, i1 in T.grid(m, n): - with T.block([m, n], "B.v0") as [i, j]: + with T.block("B.v0"): + i, j = T.axis.remap("SS", [i0, i1]) B0[i, j] = A0[i, j] + 2.0 - with T.block([m, n], "B.v1") as [i, j]: + with T.block("B.v1"): + i, j = T.axis.remap("SS", [i0, i1]) B1[i, j] = A1[i, j] * 3.0 @@ -193,7 +203,7 @@ def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) # body - with T.block([], "C"): + with T.block("C"): T.reads([A[0:128, 0:128], B[0:128, 0:128]]) T.writes([C[0:128, 0:128]]) T.evaluate( @@ -251,10 +261,12 @@ def tir_reordered_matmul(c: T.handle, a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i0, j0, k0 in T.grid(128, 128, 128): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] def test_arg_order(): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 1aae8cdd03e1..1a0dfd09a2df 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -25,17 +25,23 @@ def buffer_load_store_func(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128), "float32") C = T.alloc_buffer((128, 128), "float32") D = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128]) as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)]) as [i, j, k]: - with T.init(): + for ii, jj in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [ii, jj]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += ( + D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + ) @T.prim_func @@ -43,7 +49,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [16, 16], "float32") C = T.match_buffer(c, [16, 16], "float32") - with T.block([]): + with T.block(): T.reads([]) T.writes(B[0:16, 0:16]) A = T.allocate([256], "float32", "global") @@ -56,9 +62,8 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle")) for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] @@ -72,16 +77,20 @@ def lca_is_func_root(a: T.handle) -> None: def match_buffer_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([8, 8], "block") as [vi, vj]: - T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) - T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) - B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) - with T.block([16, 16], "AAA") as [i, j]: - AA = T.match_buffer(A[i, j], ()) - AA[()] = 1.0 - T.evaluate(B0.data) - T.evaluate(B1.data) + for i, j in T.grid(8, 8): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + for ii, jj in T.grid(16, 16): + with T.block("AAA"): + vii, vjj = T.axis.remap("SS", [ii, jj]) + AA = T.match_buffer(A[vii, vjj], ()) + AA[()] = 1.0 + T.evaluate(B0.data) + T.evaluate(B1.data) def test_buffer_load_store(): diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index e3a63c325434..4ea35c0a2d6c 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -27,57 +27,65 @@ def func() -> None: B = T.alloc_buffer((128, 128), "float32") C = T.alloc_buffer((128, 128), "float32") D = T.alloc_buffer((128, 128), "float32") - with T.block([]): + with T.block(): # Need add read/write region manually to avoid triggering block access region detector T.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]]) T.writes([A[0:12, 0:12]]) for i, j in T.grid(8, 8): A[i, j] = B[0, 0] + C[0, 0] - with T.block([2, 2]) as [vi, vj]: - T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) - T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) - for i, j in T.grid(4, 4): - A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12] + for i, j in T.grid(2, 2): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) + T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) + for i, j in T.grid(4, 4): + A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12] T.evaluate(D.data) @T.prim_func def match_buffer_func() -> None: - with T.block([], "root"): + with T.block("root"): A = T.alloc_buffer((128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") T.reads([]) T.writes([]) # Need add read/write region manually to avoid triggering block access region detector - with T.block([8, 8], "block") as [vi, vj]: - T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) - T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - AA = T.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) - B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) - B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) - with T.block([16, 16], "AAA") as [i, j]: - T.reads([]) - T.writes(AA[i, j]) - AAA = T.match_buffer(AA[i, j], ()) - AAA[()] = 1.0 - T.evaluate(B0.data) - T.evaluate(B1.data) + for i, j in T.grid(8, 8): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = T.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = T.match_buffer( + B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8) + ) + for ii, jj in T.grid(16, 16): + with T.block("AAA"): + vii, vjj = T.axis.remap("SS", [ii, jj]) + T.reads([]) + T.writes(AA[vii, vjj]) + AAA = T.match_buffer(AA[vii, vjj], ()) + AAA[()] = 1.0 + T.evaluate(B0.data) + T.evaluate(B1.data) @T.prim_func def opaque_block_func() -> None: - with T.block([], "root"): + with T.block("root"): A = T.alloc_buffer((16, 16), "float32") B = T.alloc_buffer((16, 16), "float32") T.reads([]) T.writes([]) # Need add read/write region manually to avoid triggering block access region detector for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes([B[i, 0:16]]) for j in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 @@ -88,8 +96,8 @@ def opaque_access_func() -> None: A = T.alloc_buffer([1024]) B = T.alloc_buffer([1024]) for i in T.serial(0, 8): - with T.block([8]) as [v]: - T.bind(v, i) + with T.block(): + v = T.axis.S(8, i) T.reads([A[v * 128 : v * 128 + 128]]) T.writes([B[v * 128 : v * 128 + 128]]) T.evaluate( diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 7129275aebcd..5ca9cf0da3c9 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -39,7 +39,7 @@ def buffer_load_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16)) C = T.match_buffer(c, (16, 16)) for i, j, k in T.grid(4, 16, 8): - with T.block([]): + with T.block(): T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) sub_A = T.match_buffer( @@ -55,7 +55,7 @@ def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16)) C = T.match_buffer(c, (16, 16)) for i, j, k in T.grid(4, 16, 8): - with T.block([]): + with T.block(): T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) for ii, kk in T.grid(4, 2): @@ -72,7 +72,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (32, 64, 128)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(2, 64, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) sub_A = T.match_buffer( @@ -93,7 +93,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: ) ) for i, j, k in T.grid(64, 2, 8): - with T.block([]): + with T.block(): Bs_0 = T.var("int32") Bs_1 = T.var("int32") T.reads([]) @@ -122,7 +122,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (32, 64, 128)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(2, 64, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) T.evaluate( @@ -137,7 +137,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: ) ) for i, j, k in T.grid(64, 2, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) T.evaluate( @@ -157,7 +157,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: def high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): As_0 = T.var("int32") As_1 = T.var("int32") T.reads([]) @@ -185,7 +185,7 @@ def high_dim_opaque_access(a: T.handle) -> None: def transformed_high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( @@ -205,7 +205,7 @@ def transformed_high_dim_opaque_access(a: T.handle) -> None: def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): As_0 = T.var("int32") As_1 = T.var("int32") T.reads([]) @@ -233,7 +233,7 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( @@ -254,7 +254,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (64, 64, 64)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(64, 4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -276,7 +276,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: offset_factor=1, ) for jj, kk in T.grid(4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -317,7 +317,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (64, 64, 64)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(64, 4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -326,7 +326,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: ] ) for jj, kk in T.grid(4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -362,7 +362,7 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None: A = T.match_buffer(a, (n * m, m)) B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) Bs_0 = T.var("int32") @@ -392,7 +392,7 @@ def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) A = T.match_buffer(a, (n * m, m)) B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) for ii, jj in T.grid(m, m): @@ -416,7 +416,7 @@ def rank0_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (8, 8)) B = T.match_buffer(b, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j], B[i, j]]) sub_A = T.match_buffer(A[i, j], (), offset_factor=1) @@ -440,7 +440,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (8, 8)) B = T.match_buffer(b, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j], B[i, j]]) A[i, j] = 1 @@ -461,7 +461,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: def fail_match_load(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads(A[i, j]) T.writes([]) sub_A = T.match_buffer(A[i, j], ()) @@ -472,7 +472,7 @@ def fail_match_load(a: T.handle) -> None: def fail_match_store(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j]) sub_A = T.match_buffer(A[i, j], ()) @@ -483,7 +483,7 @@ def fail_match_store(a: T.handle) -> None: def fail_buffer_bind(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): stride = T.var("int32") sub_A = T.match_buffer( A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 @@ -496,7 +496,7 @@ def fail_buffer_bind(a: T.handle) -> None: def fail_match_func_param(a: T.handle, m: T.handle, n: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): sub_A = T.match_buffer(A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1) for jj in range(0, 4): sub_A[i, j * 4 + jj] = 1 diff --git a/tests/python/unittest/test_tir_schedule_block_scope.py b/tests/python/unittest/test_tir_schedule_block_scope.py index 2182c7b9f449..ad789a010745 100644 --- a/tests/python/unittest/test_tir_schedule_block_scope.py +++ b/tests/python/unittest/test_tir_schedule_block_scope.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -58,9 +64,11 @@ def war_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index ff5b61a135eb..853f44affe5d 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -33,10 +33,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -45,20 +49,23 @@ def access_under_scope(b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([8, 8], "scope") as [i, j]: - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = A[vi, vj] + 1.0 - - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("B"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -68,76 +75,82 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: C = T.match_buffer(c, (128, 128), dtype="float16") D = T.match_buffer(d, (128, 128), dtype="float16") - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A[vi, vj]) - T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(D[vi, vj]) + D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = T.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) @T.prim_func @@ -147,15 +160,16 @@ def func_multi_consumer() -> None: C = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A[vi] @@ -163,12 +177,18 @@ def func_multi_consumer() -> None: def func_multi_producer() -> None: A = T.alloc_buffer((128)) B = T.alloc_buffer((128)) - with T.block([128], "A0") as [vi]: - A[vi] = 1.0 - with T.block([128], "A1") as [vi]: - A[vi] = 2.0 - with T.block([128], "B") as [vi]: - B[vi] = A[vi] + for i in range(128): + with T.block("A0"): + vi = T.axis.S(128, i) + A[vi] = 1.0 + for i in range(128): + with T.block("A1"): + vi = T.axis.S(128, i) + A[vi] = 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] ########## Expected function after cache_read ########## @@ -181,14 +201,22 @@ def cache_read_elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) A_global = T.alloc_buffer((128, 128)) B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A_global[vi, vj] * 2.0 - with T.block([128, 128], "B_local") as [vi, vj]: - B_local[vi, vj] = B[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_global[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B_local"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B_local[vi, vj] + 1.0 @T.prim_func @@ -198,27 +226,33 @@ def cache_read_under_scope(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) A_global = T.alloc_buffer((128, 128)) - with T.block([8, 8], "scope") as [i, j]: - A_local = T.alloc_buffer((128, 128), scope="local") - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "A_local") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_local[vi, vj] = A[vi, vj] - for x, y in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = A_local[vi, vj] + 1.0 - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A_global[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + A_local = T.alloc_buffer((128, 128), scope="local") + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("A_local"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_local[vi, vj] = A[vi, vj] + for x, y in T.grid(16, 16): + with T.block("B"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = A_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A_global[vi, vj] * 2.0 @T.prim_func @@ -229,78 +263,86 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) D = T.match_buffer(d, (128, 128), dtype="float16") A_global = T.alloc_buffer((128, 128), dtype="float16") - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A_global[vi, vj]) - T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A_global.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A_global.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi, vj]) + T.writes(D[vi, vj]) + D.data[vi * 128 + vj] = T.load("float16", A_global.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A_global.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A_global[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + C0 = T.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) @T.prim_func @@ -311,20 +353,21 @@ def cache_read_multi_consumer() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A_global[vi] = A[vi] for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A_global[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A_global[vi] @@ -335,14 +378,22 @@ def continuous_cache_read(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) B_shared = T.alloc_buffer((128, 128), scope="shared") B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B_shared") as [vi, vj]: - B_shared[vi, vj] = B[vi, vj] - with T.block([128, 128], "B_local") as [vi, vj]: - B_local[vi, vj] = B_shared[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + B_shared[vi, vj] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B_local"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = B_shared[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B_local[vi, vj] + 1.0 ########## Expected function after cache_write ########## @@ -355,14 +406,22 @@ def cache_write_elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) B_global = T.alloc_buffer((128, 128), scope="local") C_local = T.alloc_buffer((128, 128)) - with T.block([128, 128], "B_global") as [vi, vj]: - B_global[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "C_local") as [vi, vj]: - C_local[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = C_local[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B_global"): + vi, vj = T.axis.remap("SS", [i, j]) + B_global[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C_local"): + vi, vj = T.axis.remap("SS", [i, j]) + C_local[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = C_local[vi, vj] @T.prim_func @@ -372,33 +431,39 @@ def cache_write_under_scope(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) A_global = T.alloc_buffer((128, 128)) - with T.block([8, 8], "scope") as [i, j]: - A_local = T.alloc_buffer((128, 128), scope="local") - B_global = T.alloc_buffer((128, 128)) - for x, y in T.grid(16, 16): - with T.block([128, 128], "A_local") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_local[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_global[vi, vj] = A_local[vi, vj] - for x, y in T.grid(16, 16): - with T.block([128, 128], "B_global") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B_global[vi, vj] = A_global[vi, vj] + 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "B_global") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "A_global") as [vi, vj]: - A[vi, vj] = A_global[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + A_local = T.alloc_buffer((128, 128), scope="local") + B_global = T.alloc_buffer((128, 128)) + for x, y in T.grid(16, 16): + with T.block("A_local"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_local[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_global[vi, vj] = A_local[vi, vj] + for x, y in T.grid(16, 16): + with T.block("B_global"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B_global[vi, vj] = A_global[vi, vj] + 1.0 + for x, y in T.grid(16, 16): + with T.block("B_global"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -411,83 +476,95 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle B_global = T.alloc_buffer((128, 128), dtype="float16") C_global = T.alloc_buffer((128, 128), dtype="float16") - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A[vi, vj]) - T.writes(D_global[vi, vj]) - D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B_global.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(D_global[vi, vj]) + D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B_global.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C_global[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + C0 = T.match_buffer( + C_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = D_global[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = C_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = D_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = C_global[vi, vj] @T.prim_func @@ -498,20 +575,21 @@ def cache_write_multi_consumer() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A_global") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A_global"): + vi = T.axis.S(128, i * 16 + j) A_global[vi] = 1.0 for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = A_global[vi] for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A[vi] @@ -522,14 +600,22 @@ def continuous_cache_write(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) B_shared = T.alloc_buffer((128, 128), scope="shared") B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "B") as [vi, vj]: - B_local[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B") as [vi, vj]: - B_shared[vi, vj] = B_local[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_shared[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_shared[vi, vj] = B_local[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_shared[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 ########## Testcases for cache_read ########## diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 5235664595ad..6e956e1ee688 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -32,10 +32,15 @@ def two_elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -45,12 +50,13 @@ def two_elementwise_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for ax0, ax1 in T.grid(1, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i + ax0) - T.bind(vj, ax1) + with T.block("B"): + vi = T.axis.S(128, i + ax0) + vj = T.axis.S(128, ax1) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -59,22 +65,26 @@ def blockized_1(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128], "float32") B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([8, 8], "C_outer") as [vi_o, vj_o]: - T.reads([B[ - vi_o * 16 : vi_o * 16 + 16, - vj_o * 16 : vj_o * 16 + 16, - ]]) - T.writes([C[ - vi_o * 16 : vi_o * 16 + 16, - vj_o * 16 : vj_o * 16 + 16 - ]]) - for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "C_inner") as [vi, vj]: - T.bind(vi, vi_o * 16 + i_i) - T.bind(vj, vj_o * 16 + j_i) - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(8, 8): + with T.block("C_outer"): + vi_o, vj_o = T.axis.remap("SS", [i, j]) + T.reads([B[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16, + ]]) + T.writes([C[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16 + ]]) + for i_i, j_i in T.grid(16, 16): + with T.block("C_inner"): + vi = T.axis.S(128, vi_o * 16 + i_i) + vj = T.axis.S(128, vj_o * 16 + j_i) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -84,13 +94,12 @@ def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i0_0, i1_0 in T.grid(8, 8): for ax0, ax1 in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0_0 * 16 + ax0) - T.bind(vj, i1_0 * 16 + ax1) + with T.block("B"): + vi = T.axis.S(128, i0_0 * 16 + ax0) + vj = T.axis.S(128, i1_0 * 16 + ax1) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([8, 8], "C_outer") as [vi_o, vj_o]: - T.bind(vi_o, i0_0) - T.bind(vj_o, i1_0) + with T.block("C_outer"): + vi_o, vj_o = T.axis.remap("SS", [i0_0, i1_0]) T.reads([B[ vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16, @@ -100,9 +109,9 @@ def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: vj_o * 16 : vj_o * 16 + 16 ]]) for i0_1, i1_1 in T.grid(16, 16): - with T.block([128, 128], "C_inner") as [vi, vj]: - T.bind(vi, vi_o * 16 + i0_1) - T.bind(vj, vj_o * 16 + i1_1) + with T.block("C_inner"): + vi = T.axis.S(128, vi_o * 16 + i0_1) + vj = T.axis.S(128, vj_o * 16 + i1_1) C[vi, vj] = B[vi, vj] + 1.0 @@ -112,9 +121,8 @@ def blockized_2(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(8, 8): - with T.block([8, 8], "B_outer") as [vio, vjo]: - T.bind(vio, i_o) - T.bind(vjo, j_o) + with T.block("B_outer"): + vio, vjo = T.axis.remap("SS", [i_o, j_o]) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -124,14 +132,14 @@ def blockized_2(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B_inner") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B_inner"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for i_o, j_o, i_i, j_i in T.grid(4, 4, 32, 32): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 32 + i_i) - T.bind(vj, j_o * 32 + j_i) + with T.block("C"): + vi = T.axis.S(128, i_o * 32 + i_i) + vj = T.axis.S(128, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @@ -141,9 +149,8 @@ def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(8, 8): - with T.block([8, 8], "B_outer") as [vio, vjo]: - T.bind(vio, i_o) - T.bind(vjo, j_o) + with T.block("B_outer"): + vio, vjo = T.axis.remap("SS", [i_o, j_o]) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -153,14 +160,14 @@ def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B_inner") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B_inner"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for ax0, ax1 in T.grid(16, 16): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 16 + ax0) - T.bind(vj, j_o * 16 + ax1) + with T.block("C"): + vi = T.axis.S(128, i_o * 16 + ax0) + vj = T.axis.S(128, j_o * 16 + ax1) T.reads([B[vi, vj]]) T.writes([C[vi, vj]]) C[vi, vj] = B[vi, vj] + 1.0 @@ -173,9 +180,9 @@ def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(4, 4): for ax0, ax1 in T.grid(2, 2): - with T.block([8, 8], "blockized_B") as [vio, vjo]: - T.bind(vio, i_o * 2 + ax0) - T.bind(vjo, j_o * 2 + ax1) + with T.block("blockized_B"): + vio = T.axis.S(8, i_o * 2 + ax0) + vjo = T.axis.S(8, j_o * 2 + ax1) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -185,14 +192,14 @@ def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16, ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for i_i, j_i in T.grid(32, 32): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 32 + i_i) - T.bind(vj, j_o * 32 + j_i) + with T.block("C"): + vi = T.axis.S(128, i_o * 32 + i_i) + vj = T.axis.S(128, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -205,18 +212,28 @@ def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - with T.init(): - C_local[vi, vj] = 0.0 - C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C_local[vi, vj] = 0.0 + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -224,9 +241,9 @@ def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0_4, v1_4]: - T.bind(v0_4, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1_4, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0_4 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1_4 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0_4, v1_4] = C_local[v0_4, v1_4] @@ -240,14 +257,22 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -255,17 +280,17 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for i, j, k in T.grid(4, 4, 2048): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [vi, vj]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] @@ -279,14 +304,22 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -296,17 +329,17 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k_0 in T.serial(0, 256): for k_1 in T.unroll(0, 8): for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k_0 * 8 + k_1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k_0 * 8 + k_1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [vi, vj]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] @@ -320,12 +353,18 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -335,22 +374,22 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k_0 in T.serial(0, 256): for k_1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k_0 * 8 + k_1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k_0 * 8 + k_1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k_0 * 8 + k_1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k_0 * 8 + k_1) with T.init(): C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -364,10 +403,14 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -377,27 +420,27 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k0 in T.serial(0, 256): for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -411,8 +454,10 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -421,33 +466,33 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for k0 in T.serial(0, 256): for i, j in T.grid(8, 64): - with T.block([2048, 2048], "A_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, by * 64 + j) + with T.block("A_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, by * 64 + j) A_shared[v0, v1] = A[v0, v1] for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -469,38 +514,38 @@ def cuda_matmul_5(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for k0 in T.serial(0, 256): for i, j in T.grid(8, 64): - with T.block([2048, 2048], "A_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, by * 64 + j) + with T.block("A_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, by * 64 + j) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(8, 64): - with T.block([2048, 2048], "B_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, bx * 64 + j) + with T.block("B_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, bx * 64 + j) B_shared[v0, v1] = B[v0, v1] for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -510,12 +555,14 @@ def tiled(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("B"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -525,14 +572,14 @@ def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i_0, j_0, i_1 in T.grid(8, 8, 16): for j_1 in T.serial(0, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("B"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 for j_1 in T.serial(0, 16): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("C"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) C[vi, vj] = B[vi, vj] + 1.0 @@ -544,17 +591,15 @@ def factorized(a: T.handle, b: T.handle) -> None: for j in T.thread_binding(0, 16, thread = "blockIdx.x"): for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): for i_i, k in T.grid(4, 16): - with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: - T.bind(vi, i_o * 4 + i_i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B_rf"): + vi = T.axis.S(16, i_o * 4 + i_i) + vj, vk = T.axis.remap("SR", [j, k]) with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] for i, k in T.grid(16, 16): - with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, k) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] @@ -568,17 +613,17 @@ def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None: for j in T.thread_binding(0, 16, thread = "blockIdx.x"): for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): for i_i, k in T.grid(4, 16): - with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: - T.bind(vi, i_o * 4 + i_i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B_rf"): + vi = T.axis.S(16, i_o * 4 + i_i) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k) with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] for k in T.serial(0, 4): - with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]: - T.bind(vi, j) - T.bind(vk, i_o * 4 + k) + with T.block("B"): + vi = T.axis.S(16, j) + vk = T.axis.R(16, i_o * 4 + k) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] @@ -591,17 +636,19 @@ def fail_subtree_compact_dataflow(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for j in range(0, 64): - with T.block([128, 128], "B_0") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B_0"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 64): - with T.block([128, 128], "B_1") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j + 64) + with T.block("B_1"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j + 64) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -611,13 +658,16 @@ def fail_all_consumers_under_loop(a: T.handle, c: T.handle, d: T.handle) -> None C = T.match_buffer(c, (128, 128), "float32") D = T.match_buffer(d, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "D") as [vi, vj]: + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + 1.0 @@ -628,13 +678,16 @@ def fail_all_producers_under_loop(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer((128, 128), "float32") D = T.match_buffer(d, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "D") as [vi, vj]: + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + C[vi, vj] @@ -644,10 +697,12 @@ def read_out_of_bound(a: T.handle, c:T.handle) -> None: B = T.alloc_buffer([16], "float32") C = T.match_buffer(c, [16], "float32") for i in T.serial(0, 16): - with T.block([16], "B") as [v]: + with T.block("B"): + v = T.axis.S(16, i) B[v] = A[v] for j in T.serial(0, 16): - with T.block([16], "C") as [v]: + with T.block("C"): + v = T.axis.S(16, j) T.reads(B[v : v + 2]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") @@ -659,11 +714,11 @@ def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16], "float32") for j in T.serial(0, 16): for i in T.serial(0, T.min(1, 15 - j) + 1): - with T.block([16], "B") as [v]: - T.bind(v, j + i) + with T.block("B"): + v = T.axis.S(16, j + i) B[v] = A[v] - with T.block([16], "C") as [v]: - T.bind(v, j) + with T.block("C"): + v = T.axis.S(16, j) T.reads([B[v : v + 2]]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index f9049f6da732..617c75b75cd9 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -31,10 +31,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -43,12 +47,18 @@ def elementwise_multi_producer_consumer(a: T.handle, c: T.handle, d: T.handle) - B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers @T.prim_func @@ -56,10 +66,14 @@ def elementwise_multi_consumer_inlined(a: T.handle, c: T.handle, d: T.handle) -> A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] @T.prim_func @@ -67,18 +81,24 @@ def elementwise_standalone(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] + 1.0 @T.prim_func def elementwise_standalone_dce(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] + 1.0 @T.prim_func @@ -88,14 +108,12 @@ def elementwise_under_loop(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -103,8 +121,10 @@ def elementwise_under_loop(a: T.handle, c: T.handle) -> None: def elementwise_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @T.prim_func @@ -113,11 +133,15 @@ def fail_multi_reader_writer(a: T.handle, d: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.alloc_buffer((128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - C[vi, vj] = A[vi, vj] + 2.0 - with T.block([128, 128], "C") as [vi, vj]: - D[vi, vj] = B[vi, vj] + C[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + C[vi, vj] = A[vi, vj] + 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = B[vi, vj] + C[vi, vj] @T.prim_func @@ -125,18 +149,24 @@ def elementwise_multi_reverse_loads(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 @T.prim_func def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 @T.prim_func @@ -144,12 +174,16 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - T.reads(B[0:128, 0:128]) - T.writes(C[0:128, 0:128]) - C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[0:128, 0:128]) + T.writes(C[0:128, 0:128]) + C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 @T.prim_func @@ -157,13 +191,17 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - T.reads(B[0:128, 0:128]) - T.writes(C[0:128, 0:128]) - T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) - C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[0:128, 0:128]) + T.writes(C[0:128, 0:128]) + T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) + C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 @T.prim_func @@ -171,11 +209,15 @@ def buffer_matched(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) - C[vi, vj] = Bb[0, 0] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) + C[vi, vj] = Bb[0, 0] + 1.0 @T.prim_func @@ -183,10 +225,13 @@ def elementwise_predicate(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) T.where(B[i, j] < 10.0) C[vi, vj] = B[vi, vj] + 1.0 @@ -196,7 +241,8 @@ def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) T.where(A[i, j] * 2.0 < 10.0) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -206,18 +252,24 @@ def elementwise_multi_loads(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 126], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] @T.prim_func def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 126], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py index 7a9c8e01d355..ad6a1931bb0b 100644 --- a/tests/python/unittest/test_tir_schedule_error.py +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -31,10 +31,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) - for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + for k in range(128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 60269ac01c14..9075e93b9d45 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -31,9 +31,10 @@ def element_wise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -42,9 +43,8 @@ def element_wise_parallelized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i0 in T.parallel(0, 128): for i1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 @@ -54,9 +54,8 @@ def element_wise_i_bound(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i0 in T.thread_binding(0, 128, thread="threadIdx.x"): for i1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 @@ -67,14 +66,13 @@ def element_wise_compute_at_split(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o, j1i in T.grid(32, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -85,15 +83,14 @@ def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o in T.serial(0, 32): for j1i in T.vectorized(0, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -102,10 +99,10 @@ def element_wise_split_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) for i, j_0, j_1 in T.grid(128, 13, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -116,10 +113,10 @@ def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for j_0 in T.parallel(0, 13): for j_1 in T.serial(0, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -129,10 +126,10 @@ def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128]) for i in T.vectorized(0, 128): for j_0, j_1 in T.grid(13, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -143,15 +140,14 @@ def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o in T.thread_binding(0, 32, thread="threadIdx.x"): for j1i in T.serial(0, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -161,10 +157,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -172,10 +170,12 @@ def rowsum(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -184,9 +184,8 @@ def rowsum_unrolled(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i0 in T.unroll(0, 128): for i1 in T.serial(0, 128): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i0) - T.bind(vk, i1) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -198,9 +197,9 @@ def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, T.floordiv(k * k, 2)) + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, T.floordiv(k * k, 2)) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -211,10 +210,12 @@ def rowsum_not_compact_data_flow(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vk] = 0.0 - B[vk] = B[vk] + A[vi, vk] + for i, k in T.grid(128, 16): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vk] = 0.0 + B[vk] = B[vk] + A[vi, vk] @T.prim_func @@ -223,9 +224,8 @@ def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i0 in T.serial(0, 128): for i1 in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i0) - T.bind(vk, i1) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -235,7 +235,7 @@ def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: def opaque_block(a: T.handle) -> None: A = T.match_buffer(a, (16,)) for i in T.serial(0, 15): - with T.block([], "opaque"): + with T.block("opaque"): A[i + 1] = A[i + 1] + A[i] diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 8460b5cf3e66..e158f6a026e1 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -32,18 +32,17 @@ def rowsum_blockized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32, 4]) A = T.match_buffer(a, [32, 4, 128]) for i0, i2_0 in T.grid(32, 16): - with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]: - T.bind(io, i0) - T.bind(ko, i2_0) + with T.block("blockized_B"): + io, ko = T.axis.remap("SR", [i0, i2_0]) with T.init(): for i1 in T.serial(0, 4): - with T.block([4], "B_init") as [ii_init]: - T.bind(ii_init, i1) + with T.block("B_init"): + ii_init = T.axis.S(4, i1) B[io, ii_init] = 0.0 for i1_1, i2_1 in T.grid(4, 8): - with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]: - T.bind(ii, i1_1) - T.bind(k, ko * 8 + i2_1) + with T.block("B"): + ii = T.axis.S(4, i1_1) + k = T.axis.R(128, ko * 8 + i2_1) B[io, ii] = B[io, ii] + A[io, ii, k] @@ -52,11 +51,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -65,11 +65,15 @@ def matmul_decompose0(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = 0.0 + for i, j in T.grid(128, 128): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = 0.0 - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -78,16 +82,19 @@ def matmul_decompose1(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32, 4], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 32): - with T.block([32], "blockized_B_init") as [io]: + with T.block("blockized_B_init"): + io = T.axis.S(32, i0) for i1 in T.serial(0, 4): - with T.block([4], "B_init") as [ii]: + with T.block("B_init"): + ii = T.axis.S(4, i1) B[io, ii] = T.float32(0) for i0, i2_o in T.grid(32, 16): - with T.block([32, T.reduce_axis(0, 16)], "blockized_B_update") as [io, ko]: + with T.block("blockized_B_update"): + io, ko = T.axis.remap("SR", [i0, i2_o]) for i1, i2_i in T.grid(4, 8): - with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]: - T.bind(ii, i1) - T.bind(k, ((ko * 8) + i2_i)) + with T.block("B"): + ii = T.axis.S(4, i1) + k = T.axis.R(128, ko * 8 + i2_i) B[io, ii] = B[io, ii] + A[io, ii, k] @@ -98,10 +105,12 @@ def matmul_decompose2(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(128, 128): - with T.block([128, 128], "update_init") as [vi_init, vj_init]: + with T.block("update_init"): + vi_init, vj_init = T.axis.remap("SS", [i0, i1]) C[vi_init, vj_init] = T.float32(0) for i2 in T.serial(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [vi, vj, vk]: + with T.block("update_update"): + vi, vj, vk = T.axis.remap("SSR", [i0, i1, i2]) C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) @@ -112,12 +121,10 @@ def matmul_decompose_fail3(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, k, j in T.grid(128, 128, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -127,25 +134,21 @@ def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) for i0_0 in T.serial(0, 16): for i0_1_init, i1_init in T.grid(8, 128): - with T.block([128, 128], "update_init") as [vi_init, vj_init]: - T.bind(vi_init, ((i0_0 * 8) + i0_1_init)) - T.bind(vj_init, i1_init) + with T.block("update_init"): + vi_init = T.axis.S(128, i0_0 * 8 + i0_1_init) + vj_init = T.axis.S(128, i1_init) C[vi_init, vj_init] = T.float32(0) for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7): - with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [ - vi, - vj, - vk, - ]: + with T.block("update_update"): T.where((((i2_0 * 7) + i2_1) < 128)) - T.bind(vi, ((i0_0 * 8) + i0_1)) - T.bind(vj, i1) - T.bind(vk, ((i2_0 * 7) + i2_1)) + vi = T.axis.S(128, i0_0 * 8 + i0_1) + vj = T.axis.S(128, i1) + vk = T.axis.R(128, i2_0 * 7 + i2_1) C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index a60ab8dca972..8267a369cf5d 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -30,8 +30,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + for i, j, k, l in T.grid(128, 128, 128, 128): + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @T.prim_func @@ -39,11 +41,9 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 8): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l * 16) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -53,7 +53,8 @@ def elementwise_dependent_loop(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128, 128)) for i in T.serial(0, 128): for j, k, l in T.grid(128, i, 128): - with T.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -62,8 +63,9 @@ def elementwise_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -74,16 +76,12 @@ def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block([128, 128, 128], "C") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 @@ -92,12 +90,11 @@ def elementwise_with_loops_not_same_scope(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) for k in T.serial(0, 128): - with T.block([128], "B") as [vk]: - T.bind(vk, k) + with T.block("B"): + vk = T.axis.S(128, k) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -108,10 +105,9 @@ def elementwise_with_wrong_block_var_type(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block([128, 128, T.scan_axis(0, 128)], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + vk = T.axis.scan(128, k) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -122,11 +118,8 @@ def elementwise_reordered(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for l, j, k, i in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -135,11 +128,8 @@ def elementwise_reordered2(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for k, j, i, l in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -148,12 +138,9 @@ def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for l, j, k, i in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -161,14 +148,18 @@ def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None: def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") - with T.block([16, 16], "A") as [vi, vj]: - T.reads([]) - T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) - with T.block([16, 16], "B") as [vi, vj]: - T.reads([]) - T.writes([B[0:16, 0:16]]) - T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + for i, j in T.grid(16, 16): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + vj, 1) + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) @T.prim_func @@ -176,16 +167,14 @@ def opaque_access_reorder(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") for j, i in T.grid(16, 16): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, vi * 16 + vj, 1) for j, i in T.grid(16, 16): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 78b6a4696baa..bd474ed34295 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -34,10 +34,9 @@ def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - T.bind(vi, i0) - T.bind(vj, i1) - T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner)) + with T.block("update"): + vi, vj = T.axis.remap("SS", [i0, i1]) + vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) T.writes([C[vi, vj]]) with T.init(): @@ -53,18 +52,12 @@ def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([4, 128, 128]) for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block([4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf") as [ - vi2_inner_inner, - vi, - vj, - vi2_outer, - vi2_inner_outer, - ]: - T.bind(vi2_inner_inner, i2_inner_inner) - T.bind(vi, i0) - T.bind(vj, i1) - T.bind(vi2_outer, i2_outer) - T.bind(vi2_inner_outer, i2_inner_outer) + with T.block("update_rf"): + vi2_inner_inner = T.axis.S(4, i2_inner_inner) + vi = T.axis.S(128, i0) + vj = T.axis.S(128, i1) + vi2_outer = T.axis.R(4, i2_outer) + vi2_inner_outer = T.axis.R(8, i2_inner_outer) with T.init(): C_rf[vi2_inner_inner, vi, vj] = 0.0 C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( @@ -73,14 +66,8 @@ def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: ) for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4): - with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [ - vi2_inner_inner_1, - vi_1, - vj_1, - ]: - T.bind(vi2_inner_inner_1, i2_inner_inner_1) - T.bind(vi_1, i0_1) - T.bind(vj_1, i1_1) + with T.block("update"): + vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1]) with T.init(): C[vi_1, vj_1] = 0.0 C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] @@ -93,13 +80,17 @@ def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None: D = T.match_buffer(d, [256, 256]) C = T.alloc_buffer([256, 256]) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([256, 256], "D") as [vi, vj]: - D[vi, vj] = C[vi, vj] + for i, j in T.grid(256, 256): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = C[vi, vj] @T.prim_func @@ -108,10 +99,12 @@ def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] @T.prim_func @@ -122,17 +115,13 @@ def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.ha D = T.match_buffer(d, [128, 128]) for k, i, j in T.grid(128, 128, 128): - with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: - T.bind(ck, k) - T.bind(ci, i) - T.bind(cj, j) + with T.block("C"): + ck, ci, cj = T.axis.remap("RSS", [k, i, j]) with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj] - with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]: - T.bind(dk, k) - T.bind(di, i) - T.bind(dj, j) + with T.block("D"): + dk, di, dj = T.axis.remap("RSS", [k, i, j]) with T.init(): D[di, dj] = 0.0 D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj] @@ -143,10 +132,12 @@ def square_sum(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [16, 256, 256]) C = T.match_buffer(c, [16]) - with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]: - with T.init(): - C[b] = 0.0 - C[b] = C[b] + A[b, i, j] * A[b, i, j] + for b0, i0, j0 in T.grid(16, 256, 256): + with T.block("C"): + b, i, j = T.axis.remap("SRR", [b0, i0, j0]) + with T.init(): + C[b] = 0.0 + C[b] = C[b] + A[b, i, j] * A[b, i, j] @T.prim_func @@ -156,18 +147,15 @@ def square_sum_rfactor(a: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([16, 256]) for i0, i1, i2 in T.grid(16, 256, 256): - with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]: - T.bind(vi2, i2) - T.bind(b, i0) - T.bind(i, i1) + with T.block("C_rf"): + vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) with T.init(): C_rf[b, vi2] = 0.0 C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) for i0_1, i2_1 in T.grid(16, 256): - with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]: - T.bind(vi2_1, i2_1) - T.bind(b_1, i0_1) + with T.block("C"): + vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[b_1, vi2_1] @@ -180,18 +168,18 @@ def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer([16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]: - T.bind(b, i0) - T.bind(i, T.floordiv(i1_i2_fused_outer, 256)) - T.bind(j, T.floormod(i1_i2_fused_outer, 256)) + with T.block("C"): + b = T.axis.S(16, i0) + i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) T.reads([C[b], A[b, i, j]]) T.writes([C[b]]) with T.init(): C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) for i0_1 in T.serial(0, 16): - with T.block([16], "D") as [b_1]: - T.bind(b_1, i0_1) + with T.block("D"): + b_1 = T.axis.S(16, i0_1) T.reads([C[b_1]]) T.writes([D[b_1]]) D[b_1] = T.sqrt(C[b_1], dtype="float32") @@ -205,31 +193,24 @@ def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None: C_rf = T.alloc_buffer([1, 16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block([1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf") as [ - vi1_i2_fused_inner, - b, - i, - j, - ]: - T.bind(vi1_i2_fused_inner, i1_i2_fused_inner) - T.bind(b, i0) - T.bind(i, T.floordiv(i1_i2_fused_outer, 256)) - T.bind(j, T.floormod(i1_i2_fused_outer, 256)) + with T.block("C_rf"): + vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0]) + i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) with T.init(): C_rf[vi1_i2_fused_inner, b] = 0.0 C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1): - with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]: - T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1) - T.bind(b_1, i0_1) + with T.block("C"): + vi1_i2_fused_inner_1, b_1 = T.axis.remap("RS", [i1_i2_fused_inner_1, i0_1]) with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] for i0_2 in T.serial(0, 16): - with T.block([16], "D") as [b_2]: - T.bind(b_2, i0_2) + with T.block("D"): + b_2 = T.axis.S(16, i0_2) D[b_2] = T.sqrt(C[b_2], dtype="float32") @@ -238,8 +219,10 @@ def element_wise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -247,10 +230,12 @@ def rowsum(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -259,9 +244,9 @@ def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, T.floordiv(k * k, 2)) + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, T.floordiv(k * k, 2)) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -272,10 +257,12 @@ def rowsum_not_dominant(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi, vk] = 0.0 - B[vi, vk] = B[vi, vk] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi, vk] = 0.0 + B[vi, vk] = B[vi, vk] + A[vi, vk] @T.prim_func @@ -285,9 +272,8 @@ def rowsum_not_serial(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for k in T.parallel(0, 128): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, k) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -298,10 +284,12 @@ def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 1.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 1.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -309,10 +297,12 @@ def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] - A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] - A[vi, vk] @T.prim_func @@ -321,9 +311,9 @@ def rowsum_transformed(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for io, ii_ko_fused, ki in T.grid(32, 128, 4): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32)) - T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki) + with T.block("B"): + vi = T.axis.S(128, io * 4 + T.floordiv(ii_ko_fused, 32)) + vk = T.axis.R(128, T.floormod(ii_ko_fused, 32) * 4 + ki) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -334,10 +324,12 @@ def rowsum_zero_dim(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128]) B = T.match_buffer(b, []) - with T.block([T.reduce_axis(0, 128)], "B") as [k]: - with T.init(): - B[()] = 0.0 - B[()] = B[()] + A[k] + for k0 in range(128): + with T.block("B"): + k = T.axis.R(128, k0) + with T.init(): + B[()] = 0.0 + B[()] = B[()] + A[k] @T.prim_func @@ -346,15 +338,19 @@ def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, []) B_rf = T.alloc_buffer([128]) - with T.block([128], "B_rf") as [vi0]: - with T.init(): - B_rf[vi0] = 0.0 - B_rf[vi0] = B_rf[vi0] + A[vi0] + for i in range(128): + with T.block("B_rf"): + vi0 = T.axis.S(128, i) + with T.init(): + B_rf[vi0] = 0.0 + B_rf[vi0] = B_rf[vi0] + A[vi0] - with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]: - with T.init(): - B[()] = 0.0 - B[()] = B[()] + B_rf[vi0_1] + for i in range(128): + with T.block("B"): + vi0_1 = T.axis.R(128, i) + with T.init(): + B[()] = 0.0 + B[()] = B[()] + B_rf[vi0_1] @T.prim_func @@ -362,10 +358,10 @@ def rowsum_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") B = T.match_buffer(b, [128], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.block("B"): T.where(k_0 * 10 + k_1 < 128) - T.bind(vi, i) - T.bind(vk, k_0 * 10 + k_1) + vi = T.axis.S(128, i) + vk = T.axis.R(128, k_0 * 10 + k_1) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -377,18 +373,15 @@ def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") B_rf = T.alloc_buffer([128, 13], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): - with T.block([13, 128, T.reduce_axis(0, 10)], "B_rf") as [vk_0, vi, vk_1]: + with T.block("B_rf"): + vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1]) T.where(k_0 * 10 + k_1 < 128) - T.bind(vk_0, k_0) - T.bind(vi, i) - T.bind(vk_1, k_1) with T.init(): B_rf[vi, vk_0] = T.float32(0) B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1] for i, k_0 in T.grid(128, 13): - with T.block([T.reduce_axis(0, 13), 128], "B") as [vk_0, vi]: - T.bind(vk_0, k_0) - T.bind(vi, i) + with T.block("B"): + vk_0, vi = T.axis.remap("RS", [k_0, i]) with T.init(): B[vi] = T.float32(0) B[vi] = B[vi] + B_rf[vi, vk_0] @@ -405,35 +398,31 @@ def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None: for i in T.serial(0, 16): for j1 in T.serial(0, 16): for k1o, k1i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]: - T.bind(ci, i) - T.bind(cj, j1) - T.bind(ck, k1o * 4 + k1i) + with T.block("C"): + ci, cj = T.axis.remap("SS", [i, j1]) + ck = T.axis.R(16, k1o * 4 + k1i) with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, cj, ck] for k2o, k2i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]: - T.bind(di, i) - T.bind(dj, j1) - T.bind(dk, k2o * 4 + k2i) + with T.block("D"): + di, dj = T.axis.remap("SS", [i, j1]) + dk = T.axis.R(16, k2o * 4 + k2i) with T.init(): D[di, dj] = 0.0 D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] for j2 in T.serial(0, 16): for k3o, k3i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]: - T.bind(ei, i) - T.bind(ej, j2) - T.bind(ek, k3o * 4 + k3i) + with T.block("E"): + ei, ej = T.axis.remap("SS", [i, j2]) + ek = T.axis.R(16, k3o * 4 + k3i) with T.init(): E[ei, ej] = 0.0 E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] for k4o, k4i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]: - T.bind(fi, i) - T.bind(fj, j2) - T.bind(fk, k4o * 4 + k4i) + with T.block("F"): + fi, fj = T.axis.remap("SS", [i, j2]) + fk = T.axis.R(16, k4o * 4 + k4i) with T.init(): F[fi, fj] = 0.0 F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] @@ -449,46 +438,38 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: C_rf = T.alloc_buffer([16, 16, 4]) for i, j1, k1o, k1i in T.grid(16, 16, 4, 4): - with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]: - T.bind(vk1o, k1o) - T.bind(ci, i) - T.bind(cj, j1) - T.bind(vk1i, k1i) + with T.block("C_rf"): + vk1o, ci, cj, vk1i = T.axis.remap("SSSR", [k1o, i, j1, k1i]) with T.init(): C_rf[ci, cj, vk1o] = 0.0 C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)] for i_1 in T.serial(0, 16): for j1_1 in T.serial(0, 16): for k1o_1 in T.serial(0, 4): - with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]: - T.bind(vk1o_1, k1o_1) - T.bind(ci_1, i_1) - T.bind(cj_1, j1_1) + with T.block("C"): + vk1o_1, ci_1, cj_1 = T.axis.remap("RSS", [k1o_1, i_1, j1_1]) with T.init(): C[ci_1, cj_1] = 0.0 C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] for k2o, k2i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]: - T.bind(di, i_1) - T.bind(dj, j1_1) - T.bind(dk, (k2o * 4) + k2i) + with T.block("D"): + di, dj = T.axis.remap("SS", [i_1, j1_1]) + dk = T.axis.R(16, k2o * 4 + k2i) with T.init(): D[di, dj] = 0.0 D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] for j2 in T.serial(0, 16): for k3o, k3i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]: - T.bind(ei, i_1) - T.bind(ej, j2) - T.bind(ek, (k3o * 4) + k3i) + with T.block("E"): + ei, ej = T.axis.remap("SS", [i_1, j2]) + ek = T.axis.R(16, k3o * 4 + k3i) with T.init(): E[ei, ej] = 0.0 E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] for k4o, k4i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]: - T.bind(fi, i_1) - T.bind(fj, j2) - T.bind(fk, (k4o * 4) + k4i) + with T.block("F"): + fi, fj = T.axis.remap("SS", [i_1, j2]) + fk = T.axis.R(16, k4o * 4 + k4i) with T.init(): F[fi, fj] = 0.0 F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index c93c7ca63aa8..fbf0a6a5bd78 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -32,8 +32,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for i, j, k in T.grid(128, 128, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 29cfe8cadfb3..d2365c39c9cb 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -30,8 +30,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for i, j, k in T.grid(128, 128, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @T.prim_func @@ -40,7 +42,10 @@ def elementwise_dependent_loops(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i in T.serial(0, 128): for j, k in T.grid(i, 128): - with T.block([128, i, 128], "B") as [vi, vj, vk]: + with T.block("B"): + vi = T.axis.S(128, i) + vj = T.axis.S(i, j) + vk = T.axis.S(128, k) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -49,7 +54,8 @@ def elementwise_symbolic(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k in T.grid(128, 128, n): - with T.block([128, 128, n], "B") as [vi, vj, vk]: + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -58,10 +64,10 @@ def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i_j_k_fused in T.serial(0, (n * 16384)): - with T.block([128, 128, n], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(i_j_k_fused, (n * 128))) - T.bind(vj, T.floormod(T.floordiv(i_j_k_fused, n), 128)) - T.bind(vk, T.floormod(i_j_k_fused, n)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128)) + vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, n), 128)) + vk = T.axis.S(n, T.floormod(i_j_k_fused, n)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -72,11 +78,10 @@ def elementwise_symbolic_split(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k0, k1 in T.grid(128, 128, 10, T.floordiv((n + 9), 10)): - with T.block([128, 128, n], "B") as [vi, vj, vk]: + with T.block("B"): T.where((((k0 * T.floordiv((n + 9), 10)) + k1) < n)) - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, ((k0 * T.floordiv((n + 9), 10)) + k1)) + vi, vj = T.axis.remap("SS", [i, j]) + vk = T.axis.S(n, k0 * T.floordiv(n + 9, 10) + k1) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -89,10 +94,12 @@ def elementwise_with_seq(a: T.handle, b: T.handle) -> None: C = T.alloc_buffer((128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block([128, 128, 128], "C") as [vi, vj, vk]: + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 @@ -102,10 +109,8 @@ def elementwise_with_anno(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128, annotations={"useless_annotation": True}): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -117,10 +122,8 @@ def elementwise_with_thread_binding(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -132,10 +135,8 @@ def elementwise_with_starting_point(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(10, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -146,13 +147,11 @@ def elementwise_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block([], "opaque"): + with T.block("opaque"): T.reads([A[i, j, k]]) T.writes([B[i, j, k]]) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -163,10 +162,10 @@ def elementwise_fused(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for fused in T.serial(0, 2097152): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(fused, 16384)) - T.bind(vj, T.floormod(T.floordiv(fused, 128), 128)) - T.bind(vk, T.floormod(fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(fused, 16384)) + vj = T.axis.S(128, T.floormod(T.floordiv(fused, 128), 128)) + vk = T.axis.S(128, T.floormod(fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -177,10 +176,10 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, ((i1 * 64) + i3)) - T.bind(vj, ((j1 * 32) + j2)) - T.bind(vk, ((k1 * 8) + k2)) + with T.block("B"): + vi = T.axis.S(128, i1 * 64 + i3) + vj = T.axis.S(128, j1 * 32 + j2) + vk = T.axis.S(128, k1 * 8 + k2) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -191,10 +190,10 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i1 * 64 + i3) - T.bind(vj, j1 * 64 + j3) - T.bind(vk, k1 * 64 + k3) + with T.block("B"): + vi = T.axis.S(128, i1 * 64 + i3) + vj = T.axis.S(128, j1 * 64 + j3) + vk = T.axis.S(128, k1 * 64 + k3) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -205,16 +204,11 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.where( - ( - ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) - and (((k0 * 43) + k1) < 128) - ) - ) - T.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) - T.bind(vj, j1) - T.bind(vk, ((k0 * 43) + k1)) + with T.block("B"): + T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128) + vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2) + vj = T.axis.S(128, j1) + vk = T.axis.S(128, k0 * 43 + k1) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -225,7 +219,7 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) for i_j_k_fused in T.serial(0, 2097152): - with T.block([], "opaque"): + with T.block("opaque"): T.reads( [ A[ @@ -244,10 +238,10 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: ] ] ) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(i_j_k_fused, 16384)) - T.bind(vj, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) - T.bind(vk, T.floormod(i_j_k_fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384)) + vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) + vk = T.axis.S(128, T.floormod(i_j_k_fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -259,13 +253,12 @@ def elementwise_split_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) for i0, i1, j, k in T.grid(8, 16, 128, 128): - with T.block([], "opaque"): + with T.block("opaque"): T.reads([A[i0 * 16 + i1, j, k]]) T.writes([B[i0 * 16 + i1, j, k]]) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i0 * 16 + i1) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi = T.axis.S(128, i0 * 16 + i1) + vj, vk = T.axis.remap("SS", [j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -275,14 +268,18 @@ def elementwise_split_with_opaque_block(a: T.handle, b: T.handle) -> None: def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") - with T.block([16, 16], "A") as [vi, vj]: - T.reads([]) - T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) - with T.block([16, 16], "B") as [vi, vj]: - T.reads([]) - T.writes([B[0:16, 0:16]]) - T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + for i, j in T.grid(16, 16): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + vj, 1) + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) @T.prim_func @@ -290,16 +287,16 @@ def opaque_access_fused(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16]) B = T.match_buffer(b, [16, 16]) for i_j_fused in T.serial(0, 256): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, T.floordiv(i_j_fused, 16)) - T.bind(vj, T.floormod(i_j_fused, 16)) + with T.block("A"): + vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) + vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, ((vi * 16) + vj), 1, 1) for i_j_fused in T.serial(0, 256): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, T.floordiv(i_j_fused, 16)) - T.bind(vj, T.floormod(i_j_fused, 16)) + with T.block("B"): + vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) + vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")) @@ -310,16 +307,16 @@ def opaque_access_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16)) B = T.match_buffer(b, (16, 16)) for i, j0, j1 in T.grid(16, 4, 4): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, ((j0 * 4) + j1)) + with T.block("A"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, ((vi * 16) + vj), 1, 1) for i, j0, j1 in T.grid(16, 4, 4): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, ((j0 * 4) + j1)) + with T.block("B"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")) @@ -331,9 +328,9 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (127, 128)) for i in T.serial(0, 4): for j, k in T.grid(T.min(31, 126 - i * 32) + 1, 128): - with T.block([127, 128], "B") as [vi, vj]: - T.bind(vi, i * 32 + j) - T.bind(vj, k) + with T.block("B"): + vi = T.axis.S(127, i * 32 + j) + vj = T.axis.S(128, k) B[vi, vj] = A[vi, vj] @@ -343,12 +340,12 @@ def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [127, 128]) for i in T.grid(4): for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128): - with T.block([127, 128], "B") as [vi, vj]: - T.bind( - vi, + with T.block("B"): + vi = T.axis.S( + 127, i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1), ) - T.bind(vj, T.floormod(j_k_fused, 128)) + vj = T.axis.S(128, T.floormod(j_k_fused, 128)) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = A[vi, vj] diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py index 94e1b4a6b395..bc62fa1ba950 100644 --- a/tests/python/unittest/test_tir_schedule_state.py +++ b/tests/python/unittest/test_tir_schedule_state.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -55,22 +61,28 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: def block_in_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([128], "B") as vi: - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - B[vi, 0] = A[vi, 0] - if A[vi, 0] == 0.0: - with T.block([], "C"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "D") as vj: - B[vi, vj] = A[vi, vj] * 3.0 - else: - with T.block([], "E"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "F") as vj: - B[vi, vj] = A[vi, vj] * 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + B[vi, 0] = A[vi, 0] + if A[vi, 0] == 0.0: + with T.block("C"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("D"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 3.0 + else: + with T.block("E"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("F"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index e2b39ce7c289..e3bd000c2e70 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = 0.0 for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -55,22 +61,28 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: def block_in_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([128], "B") as vi: - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - B[vi, 0] = A[vi, 0] - if A[vi, 0] == 0.0: - with T.block([], "C"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "D") as vj: - B[vi, vj] = A[vi, vj] * 3.0 - else: - with T.block([], "E"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "F") as vj: - B[vi, vj] = A[vi, vj] * 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + B[vi, 0] = A[vi, 0] + if A[vi, 0] == 0.0: + with T.block("C"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("D"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 3.0 + else: + with T.block("E"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("F"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -78,10 +90,14 @@ def write_after_read(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -90,9 +106,11 @@ def loop_carried_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128,)) C = T.match_buffer(c, (128,)) for i in range(0, 128): - with T.block([128], "B") as vi: + with T.block("B"): + vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 - with T.block([128], "C") as vi: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = T.if_then_else(vi >= 1, B[vi - 1] + 1.0, 0.0, dtype="float32") @@ -101,14 +119,17 @@ def concatenate_multi_producer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 64): - with T.block([64], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(64, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 - with T.block([128], "B") as vi: - B[vi] = A[vi] * 2.0 + for i in range(0, 128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] * 2.0 @T.prim_func @@ -116,14 +137,17 @@ def concatenate_multi_producer_uncovered(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 63): - with T.block([63], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(63, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 - with T.block([128], "B") as vi: - B[vi] = A[vi] * 2.0 + for i in range(0, 128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] * 2.0 @T.prim_func @@ -132,9 +156,11 @@ def lca_at_loop(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128,)) C = T.match_buffer(c, (128,)) for i in range(0, 128): - with T.block([128], "B") as vi: + with T.block("B"): + vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 - with T.block([128], "C") as vi: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = B[vi] + 1.0 @@ -143,18 +169,20 @@ def multi_producer_consumer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 64): - with T.block([64], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(64, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 for i in range(0, 64): - with T.block([64], "B_0") as vi: + with T.block("B_0"): + vi = T.axis.S(64, i) B[vi] = A[vi] + 2.0 for i in range(0, 64): - with T.block([64], "B_1") as vi: - T.bind(vi, i + 64) + with T.block("B_1"): + vi = T.axis.S(64, i + 64) B[vi] = A[vi] + 3.0 @@ -164,12 +192,14 @@ def elementwise_affine_producer(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j, k, l in T.grid(16, 2, 32, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 8 + j * 4 + k // 8) - T.bind(vj, k % 8 * 16 + l) + with T.block("B"): + vi = T.axis.S(128, i * 8 + j * 4 + k // 8) + vj = T.axis.S(128, k % 8 * 16 + l) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -177,13 +207,19 @@ def elementwise_subblock(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([32, 32], "B") as [vi, vj]: - T.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) - T.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) - with T.block([4, 4], "B_sub") as [vi_i, vj_i]: - B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(32, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) + T.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) + for ii, jj in T.grid(4, 4): + with T.block("B_sub"): + vi_i, vj_i = T.axis.remap("SS", [ii, jj]) + B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -191,13 +227,19 @@ def elementwise_subblock_uncovered(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([32, 32], "B") as [vi, vj]: - T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) - T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) - with T.block([2, 2], "B_sub") as [vi_i, vj_i]: - B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(32, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) + T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) + for ii, jj in T.grid(2, 2): + with T.block("B_sub"): + vi_i, vj_i = T.axis.remap("SS", [ii, jj]) + B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -207,10 +249,12 @@ def bound_to_thread(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], scope="shared") for i in T.thread_binding(0, 128, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vj, vi] = B[vj, vi] + 1.0 @@ -222,14 +266,14 @@ def equal_ranked_threads(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 16, thread="threadIdx.x"): for i_i in T.thread_binding(0, 8, thread="threadIdx.y"): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_o * 8 + i_i) - T.bind(vj, j) + with T.block("B"): + vi = T.axis.S(128, i_o * 8 + i_i) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 8 + i_i) - T.bind(vj, j) + with T.block("C"): + vi = T.axis.S(128, i_o * 8 + i_i) + vj = T.axis.S(128, j) C[vj, vi] = B[vj, vi] + 1.0 @@ -241,10 +285,12 @@ def warp_memory(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: + with T.block("B"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 for j in T.serial(0, 128): - with T.block([4, 32, 128], "C") as [warp_id, lane_id, vj]: + with T.block("C"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 @@ -256,11 +302,15 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: + with T.block("B"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 for i_o_prime in T.thread_binding(0, 4, thread="threadIdx.y"): for j in T.serial(0, 128): - with T.block([4, 32, 4, 128], "C") as [_warp_id, lane_id, warp_id, vj]: + with T.block("C"): + _warp_id, warp_id, lane_id, vj = T.axis.remap( + "SSSS", [i_o, i_i, i_o_prime, j] + ) C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index 7d0e91f70e60..3b699fd8f1b2 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -29,22 +29,20 @@ def element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, ax1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) @@ -55,23 +53,21 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, ax1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) T.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) @@ -82,23 +78,21 @@ def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.block_attr({"buffer_dim_align": [0]}) - T.bind(vi, i0) - T.bind(vj, ax1) + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index 36e05c6b5170..f1c97c57b2ff 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -32,18 +32,24 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func def elementwise_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 185d229b44e1..440d0ab67a50 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -34,10 +34,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py index 86dc5dffed9f..72666a89ebcb 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -27,10 +27,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None: B = T.match_buffer(b, [m, n]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, n)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, n): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -39,10 +41,12 @@ def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -52,10 +56,12 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [m, 128]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -66,10 +72,12 @@ def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [m, x * 8]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, x * 8): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -81,11 +89,15 @@ def element_wise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((m, n), "float32") - with T.block([m, n], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(m, n): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([m, n], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(m, n): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -94,11 +106,15 @@ def element_wise_128_64(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 64), "float32") B = T.alloc_buffer((128, 64), "float32") - with T.block([128, 64], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 64): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 64], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 64): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -108,11 +124,15 @@ def element_wise_128_n(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, n), "float32") B = T.alloc_buffer((128, n), "float32") - with T.block([128, n], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, n): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, n], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, n): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -120,8 +140,10 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int32, q: T. A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q) B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q) - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -129,8 +151,10 @@ def mem_copy_16_16_8_4(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4) B = T.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4) - with T.block([16, 16], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -138,8 +162,10 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int3 A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n) B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n) - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -147,8 +173,10 @@ def param_in_arith_exprs(a: T.handle, b: T.handle) -> None: n = T.var("int32") A = T.match_buffer(a, [n // 8, 8], "int32") B = T.match_buffer(b, [n], "int32") - with T.block([n - 1], "") as [vi]: - B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 + for i in range(n - 1): + with T.block(): + vi = T.axis.S(n - 1, i) + B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 @T.prim_func @@ -156,8 +184,10 @@ def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None: n = T.var("int32") A = T.match_buffer(a, [2, 8], "int32") B = T.match_buffer(b, [16], "int32") - with T.block([15], "") as [vi]: - B[vi] = A[vi // 8, vi % 8] + 714 + for i in range(15): + with T.block(): + vi = T.axis.S(15, i) + B[vi] = A[vi // 8, vi % 8] + 714 def test_specialize_nothing(): diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 0cfc724e41de..7d3115428f5a 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -32,17 +32,17 @@ def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 @@ -53,7 +53,7 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((1, 16), "float32") @@ -74,7 +74,7 @@ def unschedulable_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") @@ -89,11 +89,11 @@ def param_buffer_access_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (20, 20), "float32") B = T.match_buffer(c, (20, 20), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(B[i, 0:16]) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 @@ -106,17 +106,17 @@ def shared_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((16, 16), "float32", scope="shared") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i0 * 8 + i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 @@ -129,17 +129,17 @@ def compacted_shared_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((8, 16), "float32", scope="shared") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i1 * 4 + i2, j]) B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0 @@ -152,17 +152,17 @@ def warp_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((16, 16), "float32", scope="warp") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i0 * 8 + i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 @@ -175,17 +175,17 @@ def compacted_warp_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((4, 16), "float32", scope="warp") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i2, j]) B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0 @@ -196,17 +196,17 @@ def symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (n * 8,), "float32") C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) B = T.alloc_buffer((n * 8,), "float32") for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 8 + j]) T.writes(B[i * 8 + j]) B[i * 8 + j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[i * 8 + j]) T.writes(C[i * 8 + j]) C[i * 8 + j] = B[i * 8 + j] * 2.0 @@ -217,17 +217,17 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (n * 8,), "float32") C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) B = T.alloc_buffer((8,), "float32") for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 8 + j]) T.writes(B[j]) B[j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[j]) T.writes(C[i * 8 + j]) C[i * 8 + j] = B[j] * 2.0 @@ -238,12 +238,12 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (8, 8), "float32") C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with T.block([]): + with T.block(): T.reads(A[0, 8]) T.writes(C[0, 8]) B = T.alloc_buffer((8, 8), "float32") for j in range(0, 4): - with T.block([]) as []: + with T.block() as []: D = T.alloc_buffer((8, 8), "float32") T.reads(A[i, j]) T.writes(B[i, j]) @@ -252,12 +252,12 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(2, 4): T.store(B.data, j, A[i, j] + D[k, j]) for j in range(3, 5): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] for j in range(6, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] @@ -268,12 +268,12 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (8, 8), "float32") C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with T.block([]): + with T.block(): T.reads(A[0, 8]) T.writes(C[0, 8]) B = T.alloc_buffer((1, 8), "float32") for j in range(0, 4): - with T.block([]) as []: + with T.block() as []: D = T.alloc_buffer((6, 1), "float32") T.reads(A[i, j]) T.writes(B[0, j]) @@ -282,12 +282,12 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(2, 4): T.store(B.data, j, A[i, j] + D[k - 2, 0]) for j in range(3, 5): - with T.block([]) as []: + with T.block() as []: T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] for j in range(6, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] @@ -298,19 +298,19 @@ def match_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) C0 = T.match_buffer(C[i, 0:16], (16)) B = T.alloc_buffer((16, 16)) - with T.block([]): + with T.block(): B0 = T.match_buffer(B[i, 0:16], (16)) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: C1 = T.match_buffer(C0[j], ()) B2 = T.match_buffer(B[i, j], ()) C1[()] = B2[()] * 2.0 @@ -321,19 +321,19 @@ def compacted_match_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) C0 = T.match_buffer(C[i, 0:16], (16)) B = T.alloc_buffer((1, 16)) - with T.block([]): + with T.block(): B0 = T.match_buffer(B[0, 0:16], (16)) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: C1 = T.match_buffer(C0[j], ()) B2 = T.match_buffer(B[0, j], ()) C1[()] = B2[()] * 2.0 @@ -344,18 +344,18 @@ def storage_align_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 @@ -366,7 +366,7 @@ def compacted_storage_align_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32") diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index 287a30916520..ee323a64c50f 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -32,19 +32,19 @@ def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) B[vi, vj] = A[vi, vj] + 1.0 for j in range(0, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) C[vi, vj] = B[vi, vj] * 2.0 @@ -53,7 +53,7 @@ def substituted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer([16, 16], "float32") diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 21c896c7bb7e..eed82ebb9118 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -32,7 +32,7 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer([1, 16], "float32", scope="global") @@ -67,7 +67,7 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 4, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="threadIdx.x"): for i2 in T.thread_binding(0, 2, thread="vthread"): - with T.block([]): + with T.block(): T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) B = T.alloc_buffer([1, 16], "float32", scope="local") @@ -108,17 +108,17 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i, m]) T.writes(C[i, m]) B = T.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[j]) B[j] = A[i, j] + 1.0 for j in range(0, m): - with T.block([]) as []: + with T.block() as []: T.reads(B[j]) T.writes(C[i, j]) C[i, j] = B[j] * 2.0 @@ -143,7 +143,7 @@ def compacted_predicate_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for i, j in T.grid(5, 7): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 7 + j]) T.writes(C[i * 7 + j]) T.where(i * 7 + j < 32) @@ -166,7 +166,7 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for x, y, z in T.grid(4, 1, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[x * 8 + y * 8 + z]) T.writes(C[x * 8 + y * 8 + z]) C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0 @@ -187,7 +187,7 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - with T.block([]) as []: + with T.block() as []: T.reads(A[i]) T.writes(D[i]) B = T.alloc_buffer((32,), scope="global") @@ -215,7 +215,7 @@ def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i0 in range(0, 4): - with T.block([]): + with T.block(): T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16]) T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16]) B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], scope="global") diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index c1c4fb3d2e8f..a4fd9404eee4 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir, te +from tvm import te from tvm.script import tir as T # pylint: disable=no-self-argument @@ -28,10 +28,13 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - with T.init(): - B[i] = T.float32(0) - B[i] += A[i, j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + with T.init(): + B[i] = T.float32(0) + B[i] += A[i, j, k] @tvm.script.ir_module @@ -41,10 +44,13 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - if (j == 0) and (k == 32): - B[i] = T.float32(0) - B[i] += A[i, j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + if (j == 0) and (k == 32): + B[i] = T.float32(0) + B[i] += A[i, j, k] @tvm.script.ir_module @@ -54,12 +60,15 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - BB = T.match_buffer(B[i], ()) - AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) - with T.init(): - BB[()] = T.float32(0) - BB[()] += AA[j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + BB = T.match_buffer(B[i], ()) + AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with T.init(): + BB[()] = T.float32(0) + BB[()] += AA[j, k] @tvm.script.ir_module @@ -69,17 +78,21 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - BB = T.match_buffer(B[i], ()) - AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) - if (j == 0) and (k == 32): - BB[()] = T.float32(0) - BB[()] += AA[j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + BB = T.match_buffer(B[i], ()) + AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) + if (j == 0) and (k == 32): + BB[()] = T.float32(0) + BB[()] += AA[j, k] def test_lower_reduction(): origin_mod = WithInit mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + print(mod.script()) tvm.ir.assert_structural_equal(mod, WithBranch, True) diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index e55555305a09..c22f5f82ee10 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir, te +from tvm import te from tvm.script import tir as T @@ -31,12 +31,14 @@ def element_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) B = T.alloc_buffer((16, 16)) - for i_0 in range(0, 16): - for j_0 in range(0, 16): - with T.block([16, 16]) as [i, j]: + for i0 in range(0, 16): + for j0 in range(0, 16): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) B[i, j] = A[i, j] + 1.0 - for j_0 in range(0, 16): - with T.block([16, 16]) as [i, j]: + for j0 in range(0, 16): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) C[i, j] = B[i, j] * 2.0 @@ -46,95 +48,112 @@ def transformed_element_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16, 16]) for i_0 in range(0, 16): - with T.block([]): + with T.block(): T.reads([A[i_0, 0:16]]) T.writes([C[i_0, 0:16]]) B = T.alloc_buffer([16, 16]) for j_0 in T.serial(0, 16): - with T.block([16, 16], "") as [i, j]: - T.bind(i, i_0) - T.bind(j, j_0) + with T.block(): + i, j = T.axis.remap("SS", [i_0, j_0]) B[i, j] = A[i, j] + 1.0 for j_0 in T.serial(0, 16): - with T.block([16, 16], "") as [i, j]: - T.bind(i, i_0) - T.bind(j, j_0) + with T.block(): + i, j = T.axis.remap("SS", [i_0, j_0]) C[i, j] = B[i, j] * 2.0 @T.prim_func def original_func() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128]) as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)]) as [i, j, k]: - B = T.alloc_buffer((128, 128), "float32") - C = T.alloc_buffer((128, 128), "float32") - D = T.alloc_buffer((128, 128), "float32") - if k == 0: + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + B = T.alloc_buffer((128, 128), "float32") + C = T.alloc_buffer((128, 128), "float32") + D = T.alloc_buffer((128, 128), "float32") + if k == 0: + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += ( + D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + ) @T.prim_func def transformed_func() -> None: A = T.alloc_buffer([128, 128]) - with T.block([128, 128], "") as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)], "") as [i, j, k]: - B = T.alloc_buffer([128, 128]) - if k == 0: + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + B = T.alloc_buffer([128, 128]) + if k == 0: + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - with T.block([], ""): - T.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) - T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) - C = T.alloc_buffer([128, 128]) - for kk in T.serial(0, 4): - B[((i * 4) + ii), ((j * 4) + jj)] = ( - B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] - ) - for kk in T.serial(0, 4): - with T.block([], ""): - T.reads( - [ - B[((i * 4) + ii), ((j * 4) + jj)], - C[((i * 4) + ii), ((k * 4) + kk)], - ] - ) - T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) - D = T.alloc_buffer([128, 128]) - B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + ( - D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)] + with T.block(""): + T.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) + T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + C = T.alloc_buffer([128, 128]) + for kk in T.serial(0, 4): + B[((i * 4) + ii), ((j * 4) + jj)] = ( + B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] ) + for kk in T.serial(0, 4): + with T.block(""): + T.reads( + [ + B[((i * 4) + ii), ((j * 4) + jj)], + C[((i * 4) + ii), ((k * 4) + kk)], + ] + ) + T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + D = T.alloc_buffer([128, 128]) + B[((i * 4) + ii), ((j * 4) + jj)] = B[ + ((i * 4) + ii), ((j * 4) + jj) + ] + ( + D[((j * 4) + jj), ((k * 4) + kk)] + * C[((i * 4) + ii), ((k * 4) + kk)] + ) @T.prim_func def match_buffer_func() -> None: C = T.alloc_buffer((128, 128)) - with T.block([128]) as [vi]: - C0 = T.match_buffer(C[vi, 0:128], (128)) - with T.block([128]) as [jj]: - C1 = T.match_buffer(C0[jj], ()) - C1[()] = 0 + for i in range(128): + with T.block(): + vi = T.axis.S(128, i) + C0 = T.match_buffer(C[vi, 0:128], (128)) + for j in range(128): + with T.block(): + jj = T.axis.S(128, j) + C1 = T.match_buffer(C0[jj], ()) + C1[()] = 0 @T.prim_func def transformed_match_buffer_func() -> None: for i in range(0, 128): - with T.block([128]) as [vi]: - T.bind(vi, i) + with T.block(): + vi = T.axis.S(128, i) C = T.alloc_buffer((128, 128)) C0 = T.match_buffer(C[vi, 0:128], (128)) - with T.block([128]) as [jj]: - C1 = T.match_buffer(C0[jj], ()) - C1[()] = 0 + for j in range(128): + with T.block(): + jj = T.axis.S(128, j) + C1 = T.match_buffer(C0[jj], ()) + C1[()] = 0 @T.prim_func @@ -143,9 +162,10 @@ def opaque_access(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [1024]) A_cache = T.alloc_buffer([1024]) for i in T.serial(0, 8): - with T.block([8]) as [vi]: - with T.block([8]) as [v]: - T.bind(v, vi) + with T.block(): + vi = T.axis.S(8, i) + with T.block(): + v = T.axis.S(8, vi) T.reads([A[(v * 128) : ((v * 128) + 128)]]) T.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) T.evaluate( @@ -161,8 +181,8 @@ def opaque_access(a: T.handle, b: T.handle) -> None: ) ) for j in T.serial(0, 128): - with T.block([1024]) as [v]: - T.bind(v, ((vi * 128) + j)) + with T.block(): + v = T.axis.S(1024, vi * 128 + j) T.reads([A_cache[v]]) T.writes([B[v]]) B[v] = A_cache[v] @@ -173,12 +193,13 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024]) B = T.match_buffer(b, [1024]) for i in T.serial(0, 8): - with T.block([8]) as [vi]: + with T.block(): + vi = T.axis.S(8, i) T.reads(A[vi * 128 : vi * 128 + 128]) T.writes(B[vi * 128 : vi * 128 + 128]) A_cache = T.alloc_buffer([1024]) - with T.block([8]) as [v]: - T.bind(v, vi) + with T.block(): + v = T.axis.S(8, vi) T.reads([A[v * 128 : v * 128 + 128]]) T.writes([A_cache[v * 128 : v * 128 + 128]]) T.evaluate( @@ -187,8 +208,8 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: ) ) for j in T.serial(0, 128): - with T.block([1024]) as [v]: - T.bind(v, ((vi * 128) + j)) + with T.block(): + v = T.axis.S(1024, vi * 128 + j) T.reads([A_cache[v]]) T.writes([B[v]]) B[v] = A_cache[v] diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 4704b27fa5fa..105b4a2d6a3f 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -26,10 +26,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -39,12 +41,14 @@ def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(32, 32): - with T.block([32, 32], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) for ii, jj in T.grid(4, 4): C[vi * 4 + ii, vj * 4 + jj] = T.float32(0) for k in range(0, 32): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) for ii, jj, kk in T.grid(4, 4, 4): C[vi * 4 + ii, vj * 4 + jj] = ( C[vi * 4 + ii, vj * 4 + jj] @@ -58,12 +62,15 @@ def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([128, 128]) as [vi, vj]: - B[vi, vj] = A[vi, vj] + T.float32(1) - - with T.block([128, 128]) as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + with T.block() as []: + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: @@ -71,12 +78,13 @@ def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([]) as []: + with T.block() as []: + with T.block() as []: B[0, 0] = A[0, 0] + T.float32(1) - - with T.block([128, 128]) as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) @T.prim_func @@ -85,14 +93,18 @@ def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([128, 128]) as [vi, vj]: - T.reads(A[vi, vj]) - B[vi, vj] = A[vi, vj] + T.float32(1) + with T.block() as []: + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + B[vi, vj] = A[vi, vj] + T.float32(1) - with T.block([128, 128]) as [vi, vj]: - T.writes(C[vi, vj]) - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) def test_complete_matmul(): @@ -181,22 +193,23 @@ def func_with_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, (1,), "int32") out_buf = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - out_buf[vi, vj] = data_buf[vi, index_buf[0]] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + out_buf[vi, vj] = data_buf[vi, index_buf[0]] @T.prim_func def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(16, 16): - with T.block([16, 16], "") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block(): + vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[vi, 0:16], index_buf[0]]) T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[vi, index_buf[0]] @@ -208,22 +221,23 @@ def func_with_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> index_buf = T.match_buffer(index, (1,), "int32") out_buf = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] @T.prim_func def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(16, 16): - with T.block([16, 16], "") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block(): + vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[0:16, 0:16], index_buf[0]]) T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] @@ -240,11 +254,11 @@ def test_complete_buffer_indices(): def match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) - with T.block([]): + with T.block(): for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) A1[()] = 1.0 @@ -253,15 +267,15 @@ def match_buffer_func(a: T.handle) -> None: def expected_match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, 0:16]) A0 = T.match_buffer(A[i, 0:16], (16)) - with T.block([]): + with T.block(): T.reads([]) T.writes(A0[0:16]) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads([]) T.writes(A0[j]) A1 = T.match_buffer(A0[j], ()) @@ -286,7 +300,7 @@ def alloc_buffer_func(a: T.handle, b: T.handle) -> None: def expect_alloc_buffer_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) B = T.match_buffer(b, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) C = T.alloc_buffer([2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 99a22636b927..80c37229f519 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -155,33 +155,83 @@ def test_allocate_with_buffers(): check_error(allocate_with_buffers, 2) -def inconsistent_binding() -> None: - with T.block([128, 128]) as [vi]: # error +def inconsistent_binding_value() -> None: + for i, j in T.grid(16, 16): + vi, vj = T.axis.remap("SS", [i]) # error + T.evaluate(1.0) + + +def inconsistent_binding_type() -> None: + for i, j in T.grid(16, 16): + vi, vj = T.axis.remap("S", [i, j]) # error T.evaluate(1.0) def test_inconsistent_binding(): - check_error(inconsistent_binding, 2) + check_error(inconsistent_binding_value, 3) + check_error(inconsistent_binding_type, 3) + + +def error_remap_type() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("TT", [i, j]) # error + T.evaluate(1.0) + + +def error_remap_value() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i + j, j]) # error + T.evaluate(1.0) + + +def test_error_remap_args(): + check_error(error_remap_type, 4) + check_error(error_remap_value, 4) def invalid_block_axes(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - with T.block([A]) as [vi]: # error - T.evaluate(1.0) + for i, j in T.grid(16, 16): + with T.block(): + vi = T.axis.S(i, A) # error + T.evaluate(1.0) def test_invalid_block_axes(): - check_error(invalid_block_axes, 3) + check_error(invalid_block_axes, 5) -def miss_block_bind() -> None: - with T.block([16, 16]) as [vi, vj]: # error - T.bind(vi, 1) - T.evaluate(1.0) +def duplicate_block_axes() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi = T.axis.S(16, i) + vi = T.axis.S(16, j) # error + T.evaluate(1.0) + + +def duplicate_block_axes_remap() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vi = T.axis.remap("SS", [i, j]) # error + T.evaluate(1.0) + + +def test_duplicate_block_axes(): + check_error(duplicate_block_axes, 5) + check_error(duplicate_block_axes_remap, 4) + + +def miss_block_bind_value() -> None: + for i, j in T.grid(128, 128): + with T.block(): + vi = T.axis.S(i) # error + T.evaluate(1.0) def test_miss_block_bind(): - check_error(miss_block_bind, 2) + check_error(miss_block_bind_value, 4) def invalid_loop_var() -> None: @@ -203,74 +253,99 @@ def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: - with T.block([16, 16]) as [vi, vj]: - A = T.match_buffer(vi) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A = T.match_buffer(vi) # error + T.evaluate(1.0) def test_invalid_match_buffer_region(): - check_error(invalid_match_buffer_region, 3) + check_error(invalid_match_buffer_region, 5) def duplicate_buffer() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A = T.alloc_buffer((128, 128), "float32") # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A = T.alloc_buffer((128, 128), "float32") # error + T.evaluate(1.0) def test_duplicate_buffer(): - check_error(duplicate_buffer, 4) + check_error(duplicate_buffer, 6) def duplicate_reads() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - T.reads(A[0:8, 0:8]) - T.reads(A[0:16, 0:16]) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[0:8, 0:8]) + T.reads(A[0:16, 0:16]) # error + T.evaluate(1.0) def duplicate_writes() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - T.writes(A[0:8, 0:8]) - T.writes(A[0:16, 0:16]) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.writes(A[0:8, 0:8]) + T.writes(A[0:16, 0:16]) # error + T.evaluate(1.0) def duplicate_predicate() -> None: - with T.block([16, 16]) as [vi, vj]: - T.where(1) - T.where(0) # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.where(1) + T.where(0) # error def duplicate_annotations() -> None: - with T.block([16, 16]) as [vi, vj]: - T.block_attr({}) - T.block_attr({}) # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({}) + T.block_attr({}) # error def duplicate_init() -> None: - with T.block([16, 16]) as [vi, vj]: - with T.init(): - T.evaluate(1.0) - with T.init(): # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + with T.init(): + T.evaluate(1.0) + with T.init(): # error + T.evaluate(1.0) + + +def duplicate_axes() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + vi = T.axis.S(i, 16) # error T.evaluate(1.0) def test_duplicate_block_signature(): - check_error(duplicate_reads, 5) - check_error(duplicate_writes, 5) - check_error(duplicate_predicate, 4) - check_error(duplicate_annotations, 4) - check_error(duplicate_init, 5) + check_error(duplicate_reads, 7) + check_error(duplicate_writes, 7) + check_error(duplicate_predicate, 6) + check_error(duplicate_annotations, 6) + check_error(duplicate_init, 7) + check_error(duplicate_axes, 5) def opaque_access_during_complete(a: T.handle) -> None: # error A = T.match_buffer(a, (16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - T.evaluate(T.load("float32", A.data, vi * 16 + vj)) + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.evaluate(T.load("float32", A.data, vi * 16 + vj)) def test_opaque_access_during_complete(): @@ -279,55 +354,65 @@ def test_opaque_access_during_complete(): def convert_slice_to_bufferload() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi : vi + 2, vj] + 1 # error + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi : vi + 2, vj] + 1 # error def test_convert_slice_to_bufferload(): - check_error(convert_slice_to_bufferload, 4) + check_error(convert_slice_to_bufferload, 6) def error_index_type() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi, 0.0] + 1 # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi, 0.0] + 1 # error def error_bufferslice_index_type() -> None: A = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - C[vi, vj] = B[vi, A[0]] # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, A[0]] # error def test_error_index_type(): - check_error(error_index_type, 4) - check_error(error_bufferslice_index_type, 6) + check_error(error_index_type, 6) + check_error(error_bufferslice_index_type, 8) def error_index_with_stop() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi, 1:10] + 1 # error + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi, 1:10] + 1 # error def error_bufferslice_index_with_stop() -> None: A = T.alloc_buffer((1,), "int32") B = T.alloc_buffer((16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - C[vi, vj] = B[vi, A[0:1]] # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, A[0:1]] # error def test_error_index_with_stop_slice(): - check_error(error_index_with_stop, 4) - check_error(error_bufferslice_index_with_stop, 6) + check_error(error_index_with_stop, 6) + check_error(error_bufferslice_index_with_stop, 8) def mismatch_args() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: + with T.block(): T.reads(A[0, 0], A[1, 1]) # error T.evaluate(1.0) @@ -338,8 +423,7 @@ def test_mismatch_args(): def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error - with T.block([16, 16]) as [vi, vj]: - T.evaluate(1.0) + T.evaluate(1.0) def scope_handler_except() -> None: @@ -368,7 +452,7 @@ def test_tvm_exception_catch(): def buffer_shape_mismatch(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j * 4 : j * 4 + 4]]) sub_A = T.match_buffer( @@ -383,7 +467,7 @@ def test_match_buffer_shape_mismatch(): def high_dim_store() -> None: - with T.block([], "root"): + with T.block("root"): B = T.allocate([256], "float32", "global") for i, j in T.grid(16, 16): B[i, j] = 1.0 # error: Store is only allowed with one index @@ -393,6 +477,15 @@ def test_high_dim_store(): check_error(high_dim_store, 5) +def block_has_option_vars() -> None: + with T.block("root") as x: # error: block does not support option_vars + T.evaluate(0.0) + + +def test_block_has_option_vars(): + check_error(block_has_option_vars, 2) + + def check_error(func, rel_lineno): # Override the default renderer to accumulate errors errors = [] @@ -416,5 +509,7 @@ def render(e): ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" +# TODO(Siyuan): block iter errors. + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py index 0aa043e09022..82f0fa5c86bc 100644 --- a/tests/python/unittest/test_tvmscript_ops.py +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -37,22 +37,25 @@ def get_valid_counts( out_buf = T.match_buffer(out, (1, 2500, 6), "float32") out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32") - with T.block([1], "init") as [vi]: + with T.block("init"): + vi = T.axis.S(1, 0) valid_count_buf[vi] = T.int32(0) - with T.block([2500], "update") as [vj]: - T.reads([data_buf[vi, vj, 6]]) - T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) - if (data_buf[vi, vj, score_index] > score_threshold) and ( - (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0)) - ): - for k in T.serial(0, 6): - out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k] - out_indices_buf[vi, valid_count_buf[vi]] = vj - valid_count_buf[vi] = valid_count_buf[vi] + 1 - if vj >= valid_count_buf[vi]: - for k in T.serial(0, 6): - out_buf[vi, vj, k] = T.float32(-1) - out_indices_buf[vi, vj] = T.int32(-1) + for j in range(2500): + with T.block("update"): + vj = T.axis.S(2500, j) + T.reads([data_buf[vi, vj, 6]]) + T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) + if (data_buf[vi, vj, score_index] > score_threshold) and ( + (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0)) + ): + for k in T.serial(0, 6): + out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k] + out_indices_buf[vi, valid_count_buf[vi]] = vj + valid_count_buf[vi] = valid_count_buf[vi] + 1 + if vj >= valid_count_buf[vi]: + for k in T.serial(0, 6): + out_buf[vi, vj, k] = T.float32(-1) + out_indices_buf[vi, vj] = T.int32(-1) def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, score_index): @@ -117,7 +120,7 @@ def alloc_zero_dim_buffer(a: T.handle, b: T.handle) -> None: def alloc_zero_dim_buffer_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (), "float32") B = T.match_buffer(b, (), "float32") - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) C = T.alloc_buffer((), "float32") diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 8058b96b024d..7c54cdc85f82 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2672,10 +2672,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -2685,11 +2687,13 @@ def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -2699,11 +2703,14 @@ def element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * T.float32(2) - - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) @T.prim_func @@ -2712,9 +2719,9 @@ def predicate(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (16, 16), "float32") for i, jo, ji in T.grid(16, 4, 5): - with T.block([16, 16], "update") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, jo * 4 + ji) + with T.block("update"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, jo * 4 + ji) T.where(jo * 4 + ji < 16) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -2807,12 +2814,16 @@ def match_buffer_region(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16), "float32") B = T.match_buffer(b, (1), "float32") - with T.block([16, 4]) as [vi, vj]: - C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) - with T.block([4]) as [vii]: - D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) - for i, j in T.grid(4, 4): - B[0] += D[i, 0, j] + for i, j in T.grid(16, 4): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + for ii in range(4): + with T.block(): + vii = T.axis.S(4, ii) + D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in T.grid(4, 4): + B[0] += D[i, 0, j] def test_match_buffer_region(): @@ -2844,8 +2855,8 @@ def block_elements(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (1, 1), "float32") - with T.block([1], "update") as [vi]: - T.bind(vi, 0) + with T.block("update"): + vi = T.axis.S(1, 0) T.where(True) T.reads(A[0:16, 0:16]) T.writes(B[0, 0]) @@ -2879,11 +2890,11 @@ def opaque_block(a: T.handle, b: T.handle) -> None: for i in range(16): for j in range(16): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j]) A[i, j] = T.float32(0) - with T.block([]): + with T.block(): T.reads([A[i, 0:16]]) T.writes([B[i, 0:16]]) for j in range(16): @@ -2927,7 +2938,7 @@ def rank0_block(a: T.handle) -> None: B = T.alloc_buffer((), "float32") T.store(B.data, 0, T.load("float32", A.data, 0)) - with T.block([], "update") as []: + with T.block("update") as []: T.reads([A[()]]) T.writes([B[()]]) for i in range(1): @@ -2969,8 +2980,10 @@ def test_minmax(): def abs(a: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") - with T.block([128, 128], "A") as [vi, vj]: - A[vi, vj] = T.abs(A[vi, vj]) + for i, j in T.grid(128, 128): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = T.abs(A[vi, vj]) def test_abs(): @@ -3011,15 +3024,13 @@ def test_simplify_bracket(): @T.prim_func def var_with_same_name(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = 0 - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = 0 for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 0 for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 0 @@ -3029,14 +3040,10 @@ def test_same_name_var(): rt_func = tvm.script.from_source(out_str) tvm.ir.assert_structural_equal(func, rt_func) - assert out_str.count("with T.block([16, 16]) as [vi, vj]") == 4 + assert out_str.count('vi, vj = T.axis.remap("SS", [i, j])') == 2 assert out_str.find("vi_") == -1 assert out_str.find("vj_") == -1 - assert out_str.count("for i0, i1 in T.grid(16, 16)") == 2 - assert out_str.find("i0_") == -1 - assert out_str.find("i1_") == -1 - assert out_str.count("for i, j in T.grid(16, 16)") == 2 assert out_str.find("i_") == -1 assert out_str.find("i_") == -1 @@ -3047,11 +3054,13 @@ def while_loop(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") i = T.alloc_buffer((), "int32", scope="local") - with T.block([16]) as [vi]: - B[vi] = 0 - while i[()] < 10: - for j in range(16): - B[j] += A[j] + for ii in range(16): + with T.block(): + vi = T.axis.S(16, ii) + B[vi] = 0 + while i[()] < 10: + for j in range(16): + B[j] += A[j] def test_while_loop(): diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index 7138effe395a..dfd2a32165f1 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -30,7 +30,7 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.4.1 +python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0 # Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in # Jenkinsfile. We expect config.cmake to be present from pack_lib(). From 75cf964b0b2d4f737b5cb25131a6c146b5edf22d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 18 Oct 2021 17:31:33 -0400 Subject: [PATCH 51/84] Test run triage (#9308) --- .asf.yaml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/.asf.yaml b/.asf.yaml index 34e813f39639..7fd3f6930fb1 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -32,3 +32,17 @@ github: - vulkan - spirv - machine-learning + + # Triage perm for collaborators(test run) + # + # The perm is given based on needs and not based on + # evaluation of past contributions. The rationale + # is that people may need the permission to start + # contributing in this way. It serves to diversify + # the ways to contribute. + # + # There is a limited number of slots. To enable broad + # participation, permission is given on a three month + # cycle. PMC may review and recycle slots when necessary. + collaborators: + - denise-k From f095595fc6ca7b8ec760be8ae2094bff1d38ec40 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Tue, 19 Oct 2021 01:07:01 -0700 Subject: [PATCH 52/84] [Codegen][LLVM] Add ability to turn on fast math flags (#9223) * flags to turn off and on * turn fast math on always * llvm more opts * move to default codegen opt * TODO * add fast math options to llvm target * move to using new target attributes * llvm fast math target opt code * add -O flags * fix todo lint * support llvm 4.0, 5.0 * use same opt level as target machine * revert TargetOptions * fix thing * prevent regression in llvm * togglable opt-levels Co-authored-by: Andrew Zhao Luo --- src/target/llvm/codegen_llvm.cc | 25 ++++++++++++++++-- src/target/llvm/codegen_llvm.h | 7 +++++ src/target/llvm/llvm_common.cc | 20 ++++++++++++-- src/target/llvm/llvm_module.cc | 47 ++++++++++++++++++++++++++++++++- src/target/target_kind.cc | 9 +++++++ 5 files changed, 103 insertions(+), 5 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index c94c5a685d1b..6c64f6798e47 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -77,6 +77,8 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, this->InitTarget(tm); } +void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } + void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { module_->setTargetTriple(tm->getTargetTriple().str()); module_->setDataLayout(tm->createDataLayout()); @@ -343,7 +345,26 @@ void CodeGenLLVM::Optimize() { // place optimization pass llvm::PassManagerBuilder builder; - builder.OptLevel = 3; + + // Use the same opt-level as specified in TargetMachine for running passes + llvm::CodeGenOpt::Level opt_level = target_machine_->getOptLevel(); + + switch (opt_level) { + case llvm::CodeGenOpt::Level::None: + builder.OptLevel = 0; + break; + case llvm::CodeGenOpt::Level::Less: + builder.OptLevel = 1; + break; + + case llvm::CodeGenOpt::Level::Default: + builder.OptLevel = 2; + break; + + default: + // CodeGenOpt::Level::Aggressive + builder.OptLevel = 3; + } #if TVM_LLVM_VERSION >= 50 builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false); @@ -410,7 +431,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { } else { return etype; } -} +} // namespace codegen llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { if (auto* ptr = type.as()) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 177b53056354..4a9df65951c0 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -78,6 +78,13 @@ class CodeGenLLVM : public ExprFunctor, */ virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, bool target_c_runtime); + + /*! + * \brief Turn on fast math flags for floating point operations. + * \param fmf FastMathFlags to use for code generation. + */ + void SetFastMathFlag(llvm::FastMathFlags fmf); + /*! * \brief Compile and add function f to the current module. * \param f The function to be added. diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index be80a8bc767e..06b2be2d9fb6 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -106,6 +106,8 @@ void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::stri #if TVM_LLVM_VERSION < 50 opt.LessPreciseFPMADOption = true; #endif + // In clang, these are fed from LangOpts which describe language specific features + // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; @@ -139,8 +141,22 @@ std::unique_ptr GetLLVMTargetMachine(const Target& target, ICHECK(allow_null) << err << " target_triple=" << target_triple; return nullptr; } - llvm::TargetMachine* tm = - llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); + + Integer llvm_opt_level = target->GetAttr("opt-level").value_or(Integer(3)); + llvm::CodeGenOpt::Level llvm_opt; + if (llvm_opt_level <= 0) { + llvm_opt = llvm::CodeGenOpt::None; + } else if (llvm_opt_level == 1) { + llvm_opt = llvm::CodeGenOpt::Less; + } else if (llvm_opt_level == 2) { + llvm_opt = llvm::CodeGenOpt::Default; + } else { + // llvm_opt_level >= 3 + llvm_opt = llvm::CodeGenOpt::Aggressive; + } + + llvm::TargetMachine* tm = llvm_target->createTargetMachine( + target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_, llvm::CodeModel::Small, llvm_opt); return std::unique_ptr(tm); } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 0e4bca4396f5..657778df0e93 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -258,8 +258,53 @@ class LLVMModuleNode final : public runtime::ModuleNode { // makes sense when we start to use multiple modules. cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); - cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + // See https://llvm.org/docs/LangRef.html#fast-math-flags for details + Bool fast_math_all = target->GetAttr("fast-math").value_or(Bool(false)); + Bool fast_math_nnan = target->GetAttr("fast-math-nnan").value_or(Bool(false)); + Bool fast_math_ninf = target->GetAttr("fast-math-ninf").value_or(Bool(false)); + Bool fast_math_nsz = target->GetAttr("fast-math-nsz").value_or(Bool(false)); + Bool fast_math_arcp = target->GetAttr("fast-math-arcp").value_or(Bool(false)); + + llvm::FastMathFlags fmf; + if (fast_math_all) { +#if TVM_LLVM_VERSION >= 60 + fmf.setFast(); +#else + fmf.setUnsafeAlgebra(); +#endif + } + + if (fast_math_nnan) { + fmf.setNoNaNs(); + } + if (fast_math_ninf) { + fmf.setNoInfs(); + } + if (fast_math_nsz) { + fmf.setNoSignedZeros(); + } + if (fast_math_arcp) { + fmf.setAllowReciprocal(); + } + +#if TVM_LLVM_VERSION >= 60 + Bool fast_math_contract = target->GetAttr("fast-math-contract").value_or(Bool(false)); + Bool fast_math_afn = target->GetAttr("fast-math-afn").value_or(Bool(false)); + Bool fast_math_reassoc = target->GetAttr("fast-math-reassoc").value_or(Bool(false)); + if (fast_math_contract) { + fmf.setAllowContract(); + } + if (fast_math_afn) { + fmf.setApproxFunc(); + } + if (fast_math_reassoc) { + fmf.setAllowReassoc(); + } +#endif + cg->SetFastMathFlag(fmf); + + cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 7cd329f83738..4403af26d1a8 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -230,6 +230,15 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") + // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("c", kDLCPU) From e7a0c5cff225632f1c1927b70758bdeb6eb250c5 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 19 Oct 2021 01:10:01 -0700 Subject: [PATCH 53/84] [Profiler] Add significant VM instructions to profiling report (#9292) Added a hooks to the VM execution loop to record runtime of certain instructions with significant runtimes. Now the profiling report will include data allocation and transfer times. --- include/tvm/runtime/profiling.h | 14 ++++- include/tvm/runtime/vm/vm.h | 20 ++++++- src/runtime/profiling.cc | 46 ++++++++++++++++ src/runtime/vm/profiler/vm.cc | 54 +++++++++++++++++++ src/runtime/vm/profiler/vm.h | 2 + src/runtime/vm/vm.cc | 27 ++++++++-- .../python/unittest/test_runtime_profiling.py | 17 +++--- 7 files changed, 167 insertions(+), 13 deletions(-) diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 7ee140622bfc..7b9a68063f16 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -452,11 +452,23 @@ class CountNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(CountNode, Object); }; -/*! \brief String representation of an array or NDArray shapes +/*! \brief String representation of an array of NDArray shapes * \param shapes Array of NDArrays to get the shapes of. * \return A textual representation of the shapes. For example: `float32[2], int64[1, 2]`. */ String ShapeString(const std::vector& shapes); +/*! \brief String representation of shape encoded as an NDArray + * \param shape NDArray containing the shape. + * \param dtype The dtype of the shape. + * \return A textual representation of the shape. For example: `float32[2]`. + */ +String ShapeString(NDArray shape, DLDataType dtype); +/*! \brief String representation of a shape encoded as a vector + * \param shape Shape as a vector of integers. + * \param dtype The dtype of the shape. + * \return A textual representation of the shape. For example: `float32[2]`. + */ +String ShapeString(const std::vector& shape, DLDataType dtype); } // namespace profiling } // namespace runtime diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 831336b9dbfe..039b1894d7c4 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -198,14 +198,14 @@ class VirtualMachine : public runtime::ModuleNode { * \param reg The register to read from. * \return The read object. */ - inline ObjectRef ReadRegister(RegName reg) const; + ObjectRef ReadRegister(RegName reg) const; /*! * \brief Read a VM register and cast it to int32_t * \param reg The register to read from. * \return The read scalar. */ - inline int64_t LoadScalarInt(RegName reg) const; + int64_t LoadScalarInt(RegName reg) const; /*! * \brief Invoke a VM function. @@ -268,6 +268,22 @@ class VirtualMachine : public runtime::ModuleNode { */ void SetInput(std::string name, TVMArgs args, int offset); + /*! + * \brief Internal hook for profiling the start of an op. + * + * This hook is only called on certain ops that are likely to take a + * significant amount of runtime (normally because they alloc or transfer to + * device). + * + * \param instr Instruction that will be executed after this hook fires + */ + virtual void OpStartHook(Instruction instr); + + /*! + * \brief Internal hook for profiling the end of an op. + */ + virtual void OpStopHook(); + protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs_; diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 8b37bfcd539d..a1d06fc8cab8 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -160,6 +161,51 @@ void Profiler::Stop() { } } +std::vector ToShape(NDArray shape_tensor) { + std::vector shape; + auto rank = shape_tensor.Shape().size(); + auto dtype = shape_tensor.DataType(); + + // For 0-rank shapes we need to allocate a single scalar. + if (rank == 0) { + return shape; + } + + // Otherwise we should be rank-1, and we will extract the number of dimensions + // for the output vector. + ICHECK_EQ(rank, 1U) << "shape tensor should be a k-length vector, found " << rank; + int64_t ndim = shape_tensor.Shape().at(0); + shape.resize(ndim); + + const DLTensor* dl_tensor = shape_tensor.operator->(); + if (dtype.is_int() && dtype.bits() == 32 && dtype.lanes() == 1) { + int32_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else if (dtype.is_int() && dtype.bits() == 64 && dtype.lanes() == 1) { + int64_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else { + LOG(FATAL) << "invalid shape tensor datatype: " << dtype; + } + + return shape; +} + +String ShapeString(NDArray shape, DLDataType dtype) { return ShapeString(ToShape(shape), dtype); } + +String ShapeString(const std::vector& shape, DLDataType dtype) { + std::stringstream sizes; + sizes << dtype << "["; + for (size_t i = 0; i < shape.size(); i++) { + if (i != 0) { + sizes << ", "; + } + sizes << shape[i]; + } + sizes << "]"; + return String(sizes.str()); +} + String ShapeString(const std::vector& shapes) { std::stringstream sizes; for (const NDArray& ary : shapes) { diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index d6575c35d10d..cd2d1332580b 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -25,6 +25,7 @@ #include "vm.h" #include +#include #include #include @@ -32,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -96,6 +98,58 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { } } +void VirtualMachineDebug::OpStartHook(Instruction instr) { + if (prof_ && prof_.operator*().IsRunning()) { + if (instr.op == Opcode::LoadConst) { + Device dev = GetDevice(exec_->const_device_type[instr.const_index]); + prof_.operator*().StartCall("VM::LoadConst", dev, {}); + } else if (instr.op == Opcode::DeviceCopy) { + Device dst_dev; + dst_dev.device_type = static_cast(instr.dst_device_type); + dst_dev.device_id = 0; + prof_.operator*().StartCall("VM::DeviceCopy", dst_dev, {}); + } else if (instr.op == Opcode::ReshapeTensor) { + prof_.operator*().StartCall("VM::ReshapeTensor", devices_[1], {}); + } else if (instr.op == Opcode::AllocTensor) { + auto shape = std::vector(instr.alloc_tensor.ndim); + + for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { + shape[i] = instr.alloc_tensor.shape[i]; + } + auto storage_obj = ReadRegister(instr.alloc_tensor.storage); + auto storage = Downcast(storage_obj); + prof_.operator*().StartCall( + "VM::AllocTensor", storage->buffer.device, + {{"Argument Shapes", profiling::ShapeString(shape, instr.alloc_tensor.dtype)}}); + } else if (instr.op == Opcode::AllocTensorReg) { + auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); + auto storage = Downcast(storage_obj); + Device cpu_dev = GetDevice(static_cast(kDLCPU)); + auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); + NDArray shape_tensor = Downcast(shape_obj).CopyTo(cpu_dev); + prof_.operator*().StartCall( + "VM::AllocTensorReg", storage->buffer.device, + {{"Argument Shapes", + profiling::ShapeString(shape_tensor, instr.alloc_tensor_reg.dtype)}}); + } else if (instr.op == Opcode::AllocStorage) { + auto size = LoadScalarInt(instr.alloc_storage.allocation_size); + std::ostringstream shape; + shape << DLDataType2String(instr.alloc_storage.dtype_hint) << "[" << size << "]"; + prof_.operator*().StartCall("VM::AllocStorage", + {static_cast(instr.alloc_storage.device_type), 0}, + {{"VM::Argument Shapes", String(shape.str())}}); + } else { + prof_.operator*().StartCall("VM::UnknownOp", devices_[1], {}); + } + } +} + +void VirtualMachineDebug::OpStopHook() { + if (prof_ && prof_.operator*().IsRunning()) { + prof_.operator*().StopCall(); + } +} + void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { ICHECK(exec_); diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 1efefda52b97..4325fa8a7999 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -51,6 +51,8 @@ class VirtualMachineDebug : public VirtualMachine { private: void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) final; + void OpStartHook(Instruction instr) final; + void OpStopHook() final; std::unordered_map packed_index_map_; dmlc::optional prof_; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c7a1baa1430d..addd5ca5d861 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -113,6 +113,9 @@ std::vector ToShape(NDArray shape_tensor) { return shape; } +void VirtualMachine::OpStartHook(Instruction instr) {} +void VirtualMachine::OpStopHook() {} + PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "invoke") { @@ -400,11 +403,9 @@ inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames_.back().register_file[r] = val; } -inline ObjectRef VirtualMachine::ReadRegister(Index r) const { - return frames_.back().register_file[r]; -} +ObjectRef VirtualMachine::ReadRegister(Index r) const { return frames_.back().register_file[r]; } -inline int64_t VirtualMachine::LoadScalarInt(Index r) const { +int64_t VirtualMachine::LoadScalarInt(Index r) const { int64_t result = 0; const auto& obj = ReadRegister(r); NDArray array = Downcast(CopyTo(obj, {kDLCPU, 0})); @@ -458,6 +459,11 @@ void VirtualMachine::RunLoop() { throw std::runtime_error("VM encountered fatal error"); } case Opcode::LoadConst: { + bool is_not_cached = const_pool_.size() <= static_cast(instr.const_index) || + !const_pool_[instr.const_index].defined(); + if (is_not_cached) { + OpStartHook(instr); + } auto constant_obj = exec_->constants[instr.const_index]; // We cache the allocated object in the constant pool. To measure, the // first iteration will set the pool up. The other iterations will @@ -471,6 +477,9 @@ void VirtualMachine::RunLoop() { const_pool_[instr.const_index] = CopyTo(constant_obj, dev); } WriteRegister(instr.dst, const_pool_[instr.const_index]); + if (is_not_cached) { + OpStopHook(); + } pc_++; goto main_loop; } @@ -560,6 +569,7 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::AllocTensor: { + OpStartHook(instr); auto shape = std::vector(instr.alloc_tensor.ndim); for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { @@ -572,10 +582,12 @@ void VirtualMachine::RunLoop() { auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype); WriteRegister(instr.dst, obj); + OpStopHook(); pc_++; goto main_loop; } case Opcode::AllocTensorReg: { + OpStartHook(instr); Device cpu_dev = GetDevice(static_cast(kDLCPU)); auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = Downcast(CopyTo(shape_obj, cpu_dev)); @@ -586,6 +598,7 @@ void VirtualMachine::RunLoop() { auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor_reg.dtype); WriteRegister(instr.dst, obj); + OpStopHook(); pc_++; goto main_loop; } @@ -609,6 +622,7 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::AllocStorage: { + OpStartHook(instr); auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = instr.alloc_storage.alignment; @@ -625,6 +639,7 @@ void VirtualMachine::RunLoop() { storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint); Storage storage(storage_obj); WriteRegister(instr.dst, storage); + OpStopHook(); pc_++; goto main_loop; } @@ -656,6 +671,7 @@ void VirtualMachine::RunLoop() { } } case Opcode::ReshapeTensor: { + OpStartHook(instr); Device cpu_dev = GetDevice(static_cast(kDLCPU)); auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor); NDArray tensor_arr = Downcast(tensor_obj); @@ -671,10 +687,12 @@ void VirtualMachine::RunLoop() { // Reshape the input tensor auto out_tensor = tensor_arr.CreateView(shape, tensor_arr->dtype); WriteRegister(instr.dst, out_tensor); + OpStopHook(); pc_++; goto main_loop; } case Opcode::DeviceCopy: { + OpStartHook(instr); auto tensor_src = ReadRegister(instr.src); NDArray src_data = Downcast(tensor_src); Device src_dev = src_data->device; @@ -686,6 +704,7 @@ void VirtualMachine::RunLoop() { NDArray dst_data = src_data.CopyTo(dst_dev); WriteRegister(instr.dst, dst_data); + OpStopHook(); pc_++; goto main_loop; } diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index ca6cb0181489..3e38a526855a 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -52,15 +52,20 @@ def read_csv(report): @pytest.mark.skipif(not profiler_vm.enabled(), reason="VM Profiler not enabled") @tvm.testing.parametrize_targets def test_vm(target, dev): - mod, params = mlp.get_workload(1) - - exe = relay.vm.compile(mod, target, params=params) + dtype = "float32" + x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype=dtype) + y = relay.var("y", shape=(relay.Any(), relay.Any()), dtype=dtype) + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], relay.add(x, y)) + exe = relay.vm.compile(mod, target) vm = profiler_vm.VirtualMachineProfiler(exe, dev) - data = np.random.rand(1, 1, 28, 28).astype("float32") - report = vm.profile(data, func_name="main") - assert "fused_nn_softmax" in str(report) + data = np.random.rand(28, 28).astype("float32") + report = vm.profile(data, data, func_name="main") + assert "fused_add" in str(report) assert "Total" in str(report) + assert "AllocTensorReg" in str(report) + assert "AllocStorage" in str(report) csv = read_csv(report) assert "Hash" in csv.keys() From 0147b0445dc55633b05f4491072852fd4e2ce835 Mon Sep 17 00:00:00 2001 From: Chris Hoge Date: Tue, 19 Oct 2021 06:55:11 -0700 Subject: [PATCH 54/84] Fix direct and broken links (#9314) Updates links to use references instead of direct links, fixing broken links and making all internal docs links more durable to refactoring --- docs/dev/how_to/relay_add_op.rst | 18 ++++++++---------- docs/how_to/deploy/arm_compute_lib.rst | 6 +++--- docs/reference/api/python/tir.rst | 2 ++ docs/topic/vta/install.rst | 10 ++++------ .../deploy_prequantized_tflite.py | 10 +++++----- .../work_with_schedules/schedule_primitives.py | 2 ++ gallery/tutorial/autotvm_relay_x86.py | 12 +++++------- gallery/tutorial/install.py | 4 ++-- gallery/tutorial/intro_topi.py | 2 ++ gallery/tutorial/tensor_expr_get_started.py | 2 +- gallery/tutorial/tvmc_command_line_driver.py | 8 +++----- vta/tutorials/README.txt | 2 ++ 12 files changed, 39 insertions(+), 39 deletions(-) diff --git a/docs/dev/how_to/relay_add_op.rst b/docs/dev/how_to/relay_add_op.rst index f9ade45f0800..2a8c771dc63d 100644 --- a/docs/dev/how_to/relay_add_op.rst +++ b/docs/dev/how_to/relay_add_op.rst @@ -190,18 +190,16 @@ useful for fusing operators. ``kOpaque`` tells TVM to not bother trying to fuse While we've now defined the interface for our operations we still need to define how to perform the actual calculations for cumulative sum and product. -Writing this code is outside the scope of the tutorial. For now, we assume -we have a well tested implementation for the operation's compute. For -more details on how to do this, we recommend looking up the tutorials -on `tensor expressions`_, `TVM's operator inventory (topi)`_ and looking at the -example cumulative sum and product implementations found in `python/tvm/topi/scan.py`_ -and the gpu versions in `python/tvm/topi/cuda/scan.py`_. In the case of our cumulative -sum and product operations we write things directly in `TIR`_ which is the +Writing this code is outside the scope of the tutorial. For now, we assume we +have a well tested implementation for the operation's compute. For more details +on how to do this, we recommend looking up the tutorials on :ref:`tensor +expressions `, :ref:`TVM's operator inventory +(topi) ` and looking at the example cumulative sum and product +implementations found in `python/tvm/topi/scan.py`_ and the gpu versions in +`python/tvm/topi/cuda/scan.py`_. In the case of our cumulative sum and product +operations we write things directly in :ref:`TIR ` which is the representation where tensor expressions and topi will lower into. -.. _tensor expressions: https://tvm.apache.org/docs/tutorials/get_started/tensor_expr_get_started.html -.. _TVM's operator inventory (topi): https://tvm.apache.org/docs/tutorials/topi/intro_topi.html -.. _TIR: https://tvm.apache.org/docs/dev/index.html?highlight=tir#tvm-tir .. _python/tvm/topi/scan.py: https://github.com/apache/tvm/blob/main/python/tvm/topi/scan.py .. _python/tvm/topi/cuda/scan.py: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/scan.py diff --git a/docs/how_to/deploy/arm_compute_lib.rst b/docs/how_to/deploy/arm_compute_lib.rst index 6fb531a0a8f6..831438273cca 100644 --- a/docs/how_to/deploy/arm_compute_lib.rst +++ b/docs/how_to/deploy/arm_compute_lib.rst @@ -142,9 +142,9 @@ Export the module. lib.export_library(lib_path, cc=cross_compile) -Run Inference. This must be on an Arm device. If compiling on x86 device and running on AArch64, -consider using the RPC mechanism. Tutorials for using the RPC mechanism: -https://tvm.apache.org/docs/tutorials/get_started/cross_compilation_and_rpc.html +Run Inference. This must be on an Arm device. If compiling on x86 device and +running on AArch64, consider using the RPC mechanism. :ref:`Tutorials for using +the RPC mechanism ` .. code:: python diff --git a/docs/reference/api/python/tir.rst b/docs/reference/api/python/tir.rst index b0b8f1cff5fb..2152be69ea6f 100644 --- a/docs/reference/api/python/tir.rst +++ b/docs/reference/api/python/tir.rst @@ -15,6 +15,8 @@ specific language governing permissions and limitations under the License. +.. _api-python-tir: + tvm.tir ------- .. automodule:: tvm.tir diff --git a/docs/topic/vta/install.rst b/docs/topic/vta/install.rst index 2248975b61b1..e4b309ea9b61 100644 --- a/docs/topic/vta/install.rst +++ b/docs/topic/vta/install.rst @@ -30,8 +30,8 @@ We present three installation guides, each extending on the previous one: VTA Simulator Installation -------------------------- -You need `TVM installed `_ on your machine. -For a quick and easy start, checkout the `Docker Guide `_. +You need :ref:`TVM installed ` on your machine. For a quick and +easy start, checkout the :ref:`Docker Guide `. You'll need to set the following paths to use VTA: @@ -65,7 +65,7 @@ To ensure that you've properly installed the VTA python package, run the followi python /vta/tests/python/integration/test_benchmark_topi_conv2d.py -You are invited to try out our `VTA programming tutorials `_. +You are invited to try out our :ref:`VTA programming tutorials `. **Note**: You'll notice that for every convolution layer, the throughput gets reported in GOPS. These numbers are actually the computational throughput that the simulator achieves, by evaluating the convolutions in software. @@ -222,9 +222,7 @@ The performance metrics measured on the Pynq board will be reported for each con **Tip**: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq ``ssh`` session. -You can also try out our `VTA programming tutorials `_. - - +You can also try out our :ref:`VTA programming tutorials `. Intel DE10 FPGA Setup --------------------- diff --git a/gallery/how_to/deploy_models/deploy_prequantized_tflite.py b/gallery/how_to/deploy_models/deploy_prequantized_tflite.py index 7bbb06bdf801..830e2ab07466 100644 --- a/gallery/how_to/deploy_models/deploy_prequantized_tflite.py +++ b/gallery/how_to/deploy_models/deploy_prequantized_tflite.py @@ -255,8 +255,8 @@ def run_tvm(lib): # * Set the environment variable TVM_NUM_THREADS to the number of physical cores # * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or # "llvm -mcpu=cascadelake" (more CPUs with AVX512 would come in the future) -# * Perform autotuning - `Auto-tuning a convolution network for x86 CPU -# `_. -# * To get best inference performance on ARM CPU, change target argument according to your -# device and follow `Auto-tuning a convolution network for ARM CPU -# `_. +# * Perform autotuning - :ref:`Auto-tuning a convolution network for x86 CPU +# `. +# * To get best inference performance on ARM CPU, change target argument +# according to your device and follow :ref:`Auto-tuning a convolution +# network for ARM CPU `. diff --git a/gallery/how_to/work_with_schedules/schedule_primitives.py b/gallery/how_to/work_with_schedules/schedule_primitives.py index ade79f69707f..65fdeda57c3b 100644 --- a/gallery/how_to/work_with_schedules/schedule_primitives.py +++ b/gallery/how_to/work_with_schedules/schedule_primitives.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """ +.. _schedule_primitives: + Schedule Primitives in TVM ========================== **Author**: `Ziheng Jiang `_ diff --git a/gallery/tutorial/autotvm_relay_x86.py b/gallery/tutorial/autotvm_relay_x86.py index 67faec4505a6..8b9c45c2a859 100644 --- a/gallery/tutorial/autotvm_relay_x86.py +++ b/gallery/tutorial/autotvm_relay_x86.py @@ -81,10 +81,9 @@ # # .. note:: Working with Other Model Formats # -# TVM supports many popular model formats. A list can be found in the `Compile -# Deep Learning Models -# `_ -# section of the TVM Documentation. +# TVM supports many popular model formats. A list can be found in the +# :ref:`Compile Deep Learning Models ` section of the TVM +# Documentation. model_url = "".join( [ @@ -150,9 +149,8 @@ # # Specifying the correct target can have a huge impact on the performance of # the compiled module, as it can take advantage of hardware features -# available on the target. For more information, please refer to `Auto-tuning -# a convolutional network for x86 CPU -# `_. +# available on the target. For more information, please refer to +# :ref:`Auto-tuning a convolutional network for x86 CPU `. # We recommend identifying which CPU you are running, along with optional # features, and set the target appropriately. For example, for some # processors ``target = "llvm -mcpu=skylake"``, or ``target = "llvm diff --git a/gallery/tutorial/install.py b/gallery/tutorial/install.py index b69b8b493a4f..67ce093b9d7f 100644 --- a/gallery/tutorial/install.py +++ b/gallery/tutorial/install.py @@ -35,8 +35,8 @@ # allow you to enable specific features such as GPU support, microcontroller # support (microTVM), and a debugging runtime, and other features. You will also # want to install from source if you want to actively contribute to the TVM -# project. The full instructions are on the `Install TVM From Source -# `_ page. +# project. The full instructions are on the :ref:`Install TVM From Source +# ` page. ################################################################################ # Installing From Binary Packages diff --git a/gallery/tutorial/intro_topi.py b/gallery/tutorial/intro_topi.py index 8138e4718cd9..dad8c53bf4ae 100644 --- a/gallery/tutorial/intro_topi.py +++ b/gallery/tutorial/intro_topi.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """ +.. _tutorial-topi: + Introduction to TOPI ==================== **Author**: `Ehsan M. Kermani `_ diff --git a/gallery/tutorial/tensor_expr_get_started.py b/gallery/tutorial/tensor_expr_get_started.py index 310d6bdbfee4..fda332cb63ba 100644 --- a/gallery/tutorial/tensor_expr_get_started.py +++ b/gallery/tutorial/tensor_expr_get_started.py @@ -512,7 +512,7 @@ def evaluate_addition(func, target, optimization, log): # before it moves on to the next stage. # # A complete description of these primitives can be found in the -# [Schedule Primitives](https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html) docs page. +# :ref:`Schedule Primitives ` docs page. ################################################################################ # Example 2: Manually Optimizing Matrix Multiplication with TE diff --git a/gallery/tutorial/tvmc_command_line_driver.py b/gallery/tutorial/tvmc_command_line_driver.py index ea3254054ecf..7a0b97895e4f 100644 --- a/gallery/tutorial/tvmc_command_line_driver.py +++ b/gallery/tutorial/tvmc_command_line_driver.py @@ -154,11 +154,9 @@ # Specifying the correct target (option ``--target``) can have a huge # impact on the performance of the compiled module, as it can take # advantage of hardware features available on the target. For more -# information, please refer to `Auto-tuning a convolutional network -# for x86 CPU `_. -# We recommend identifying which CPU you are running, along with optional features, -# and set the target appropriately. -# +# information, please refer to :ref:`Auto-tuning a convolutional network for +# x86 CPU `. We recommend identifying which CPU you are +# running, along with optional features, and set the target appropriately. ################################################################################ # Running the Model from The Compiled Module with TVMC diff --git a/vta/tutorials/README.txt b/vta/tutorials/README.txt index 3d3858b111ba..c1ff4ca0444d 100644 --- a/vta/tutorials/README.txt +++ b/vta/tutorials/README.txt @@ -1,3 +1,5 @@ +.. _vta-tutorials: + VTA Tutorials ============= This page contains tutorials about VTA and how to use TVM/Relay to target VTA. From 31c171ebebf522f041c4d52c2e8454afb25e285e Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Tue, 19 Oct 2021 11:24:50 -0700 Subject: [PATCH 55/84] [Keras] Support return_sequences in LSTM (#9303) --- python/tvm/relay/frontend/keras.py | 6 +++++- tests/python/frontend/keras/test_forward.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index aa185923d02e..bf6293a2a90c 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -896,6 +896,7 @@ def _convert_lstm(inexpr, keras_layer, etab): in_data = _op.squeeze(in_data, axis=[0]) in_data = _op.split(in_data, indices_or_sections=time_steps, axis=0) # loop for the number of time_steps + out_list = [] # store h outputs in case return_sequences is True for data in in_data: ixh1 = _op.nn.dense(data, kernel_weight, units=units) ixh2 = _op.nn.bias_add(_op.nn.dense(next_h, recurrent_weight, units=units), bias=in_bias) @@ -906,8 +907,11 @@ def _convert_lstm(inexpr, keras_layer, etab): next_c = in_transform * next_c + in_gate * _convert_activation(gates[2], keras_layer, None) out_gate = _convert_recurrent_activation(gates[3], keras_layer) next_h = out_gate * _convert_activation(next_c, keras_layer, None) + if keras_layer.return_sequences: + out_list.append(_op.expand_dims(next_h, axis=1)) + out = _op.concatenate(out_list, axis=1) if keras_layer.return_sequences else next_h out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0]) - out = _op.reshape(next_h, newshape=out_shape) + out = _op.reshape(out, newshape=out_shape) return [out, next_h, next_c] diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 26bf58cbf384..4dfe89fe40e5 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -417,6 +417,17 @@ def test_forward_reuse_layers(self, keras): keras_model = keras.models.Model(data, z) verify_keras_frontend(keras_model) + def test_forward_lstm(self, keras): + data = keras.layers.Input(shape=(10, 32)) + rnn_funcs = [ + keras.layers.LSTM(16), + keras.layers.LSTM(16, return_sequences=True), + ] + for rnn_func in rnn_funcs: + x = rnn_func(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + def test_forward_rnn(self, keras): data = keras.layers.Input(shape=(1, 32)) rnn_funcs = [ @@ -613,6 +624,7 @@ def test_forward_nested_layers(self, keras): sut.test_forward_multi_inputs(keras=k) sut.test_forward_multi_outputs(keras=k) sut.test_forward_reuse_layers(keras=k) + sut.test_forward_lstm(keras=k) sut.test_forward_rnn(keras=k) sut.test_forward_vgg16(keras=k) sut.test_forward_vgg16(keras=k, layout="NHWC") From 6701b78c8b91b5aea70ef4bc11839eda38af90bc Mon Sep 17 00:00:00 2001 From: Mastize Date: Wed, 20 Oct 2021 06:09:23 +0800 Subject: [PATCH 56/84] fix missing span arg (#9318) --- python/tvm/relay/quantize/quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 3b4d97576cd7..7f4724db22b2 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -51,7 +51,7 @@ def kind2str(kind): def _forward_op(ref_call, args): """forward the operator of ref_call with provided arguments""" - return _expr.Call(ref_call.op, args, ref_call.attrs, ref_call.type_args) + return _expr.Call(ref_call.op, args, ref_call.attrs, ref_call.type_args, ref_call.span) @tvm._ffi.register_object("relay.quantize.QConfig") From 0a5a029e485c8dc35ca33257d477b578b37e39e1 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 19 Oct 2021 15:56:39 -0700 Subject: [PATCH 57/84] [Community] @elvin-n -> Reviewer (#9321) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b9ef0479c72f..19287b4cbfd5 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -111,6 +111,7 @@ We do encourage everyone to work anything they are interested in. - [Andrew Z. Luo](https://github.com/AndrewZhaoLuo): @AndrewZhaoLuo - [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - [Masahiro Masuda](https://github.com/masahi): @masahi +- [Andrey Malyshev](https://github.com/elvin-n): @elvin-n - [Sergey Mironov](https://github.com/grwlf): @grwlf - [Thierry Moreau](https://github.com/tmoreau89): @tmoreau89 - [Kazutaka Morita](https://github.com/kazum): @kazum From af09ac917f201f8898c78af43129b0e1018a32d3 Mon Sep 17 00:00:00 2001 From: Adam Straw Date: Wed, 20 Oct 2021 08:14:07 -0700 Subject: [PATCH 58/84] Adjust Hexagon conv2d schedule to split channel out (k) and move to outer loop (#9287) * Adjust Hexagon conv2d schedule to split channel out (k) and move to outermost loop * add missing reference data verify --- tests/python/contrib/test_hexagon/README.md | 448 +++++++++++------- .../test_hexagon/test_conv2d_blocked.py | 86 +++- 2 files changed, 346 insertions(+), 188 deletions(-) diff --git a/tests/python/contrib/test_hexagon/README.md b/tests/python/contrib/test_hexagon/README.md index 1d6a298d48d6..a47c3438bf57 100644 --- a/tests/python/contrib/test_hexagon/README.md +++ b/tests/python/contrib/test_hexagon/README.md @@ -29,14 +29,14 @@ Documents manual TE schedule to illustrate Hexagon operator slicing. * Added spacing and line breaks * Naming conventions * Using input (instead of activation) - * Using kernel (instead of weight, filter) + * Using filter (instead of weight, kernel) * Using `k` to denote channel-out and `c` or `rc` (reduction channel) to denote channel-in - * Using `rh` and `rw` (reduction height / width) to denote kernel height and width + * Using `rh` and `rw` (reduction height / width) to denote filter height and width # Calling Convention TODO: Map this packed string to parameters -conv2d_packed_filter-1-1-0-float32-1-1-64-64-64-llvm +conv2d_packed_filter-1-1-0-float32-1-1-1-64-64-128-llvm # Baseline conv2d @@ -44,70 +44,80 @@ This is a baseline 1x1 conv2d schedule for Hexagon. ## Command -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-1-1-64-64-64-llvm]" +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-1-1-1-64-64-128-llvm]" ## Parameters | Parameter | Value | | --------- | ----------- | | Batch | 1 | -| Kernel | 1x1 | +| Filter | 1x1 | | Spatial | 64x64 | | Input Ch | 64 | -| Output Ch | 64 | +| Output Ch | 128 | | Stride | 1 | | Padding | 0 | | Layout | NHWC8h8w32c | ## Assumptions -* Microkernels will compute "full depth" in channel-out (k) dimension. - * The compute schedule (see TIR below) - * Places the outer channel-out loop over `ko` inside the outer width loop over `wo` - * Encodes the assumption that Hexagon microkernels will compute "full depth" in the channel-out (k) dimension +* Pattern matching for microkernels is not senstive to cache reads and writes between the outer height (ho) and outer width (wo) loops. ## To Do -* Adjust compute schedule and add kernel cache read once Hexagon microkernel semantics are understood - +* n/a + ## Annotated TIR ``` -primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), // NHWC8h8w32c - kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending layout RFC) - buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { - + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { allocate(input.cache: Pointer(global float32), float32, [32768]), storage_scope = global; - allocate(output.cache: Pointer(global float32), float32, [32768]), storage_scope = global; + allocate(filter.cache: Pointer(global float32), float32, [2048]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [16384]), storage_scope = global; + + for (ko.outer: int32, 0, 4) { + for (ho.outer: int32, 0, 8) { + + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } - for (ho.outer: int32, 0, 8) { - // cache read - // NHWC -> NHWC8h8w32c (pending layout RFC) - for (wo: int32, 0, 8) { + // filter cache read for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[((((co*1024) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[(((((ko.outer*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] } } } } - } - // compute - for (wo.c: int32, 0, 8) { - for (ko.c: int32, 0, 2) { - + // compute + for (wo.c: int32, 0, 8) { + // init output cache for (hi.c.init: int32, 0, 8) { for (wi.c.init: int32, 0, 8) { for (ki.c.init: int32, 0, 32) { - output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + output.cache[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 } } } @@ -118,173 +128,220 @@ primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () for (wi.c: int32, 0, 8) { for (ki.c: int32, 0, 32) { for (rc.inner: int32, 0, 32) { - output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = ( - (float32*)output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + (float32*)output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + ( (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * - (float32*)kernel_pointer[(((((ko.c*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + (float32*)filter.cache[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] ) ) } } } } - } // end rc.outer - } // end ko.c - } // end wo.c + } + } // end wo.c - // cache write - for (wo: int32, 0, 8) { - for (ko: int32, 0, 2) { + // cache write + for (wo: int32, 0, 8) { for (hi: int32, 0, 8) { for (wi: int32, 0, 8) { for (ki: int32, 0, 32) { - output_pointer[((((((ho.outer*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[(((((wo*4096) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + output_pointer[((((((ho.outer*65536) + (wo*8192)) + (ko.outer*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((wo*2048) + (hi*256)) + (wi*32)) + ki)] } } } } - } - } + } // end ho.outer + } // end ko.outer } ``` -# Split on Height - "Full Output Slice" +# Split on Channel Out and Height - "Full Output Slice" -Adds a new parameter `h_split` which creates a loop split on the height `h` dimension. The cache reads and writes are moved to the outer of the two loops created by that split - the loop over `ho.outer`. This increases cache usage by a factor equivalent to `h_split`. The compute is still "full width" and "full depth" in the channel-out dimension and now over multiple slices in the height `h` dimension. +Adds new parameters `k_split` and `h_split` which creates a loop split on the outer channel out `ko` and height `ho` loops creating `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer` which means that cache allocation grow in proportion to `k_split` and `h_split` factors. -The key changes in TIR versus the baseline are ... +The key changes in TIR versus the above are... 1) Increased cache allocations: ``` + // input cache grows by factor of h_split = 2 allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // filter cache grows by factor of k_split = 2 + allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; + + // output cache grows by factor of h_split * k_split = 4 allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; ``` -2) The loop split on the `h` dimension: +2) Outer loop splits using k_split and h_split factors ``` - for (ho.outer: int32, 0, 4) { - for (ho.inner: int32, 0, 2) { + // ko.outer = outer loop split on ko using k_split factor + for (ko.outer: int32, 0, 2) { + // ho.outer = outer loop split on ho using h_split factor + for (ho.outer: int32, 0, 4) { +``` + +3) Inner loop splits in both cache read / write and compute schedules. This is taken from the compute schedule e.g. +``` + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { ``` ## Command -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-1-64-64-64-llvm]" +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-2-1-64-64-128-llvm]" ## Parameters | Parameter | Value | | --------- | ----------- | | Batch | 1 | -| Kernel | 1x1 | +| Filter | 1x1 | | Spatial | 64x64 | | Input Ch | 64 | -| Output Ch | 64 | +| Output Ch | 128 | | Stride | 1 | | Padding | 0 | | Layout | NHWC8h8w32c | +| k_split | 2 | | h_split | 2 | ## Assumptions -Same as baseline +* n/a - With the loop splits on `ko` and `ho` the compute schedule is now over `ko.inner` `ho.inner` `wo` etc. This should fit the pattern matching for microkernels. ## To Do -Same as baseline +* n/a ## Annotated TIR ``` -primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), - kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 1, 1, 8, 32, 4], []), - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} - buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { - - // increased cache usage due to h_split parameter + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + + // input cache grows by factor of h_split = 2 allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // filter cache grows by factor of k_split = 2 + allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; + + // output cache grows by factor of h_split * k_split = 4 allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // ko.outer = outer loop split on ko using k_split factor + for (ko.outer: int32, 0, 2) { + // ho.outer = outer loop split on ho using h_split factor + for (ho.outer: int32, 0, 4) { + + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + } // end ho.inner - // loop split ho.outer vs. ho.inner based on h_split parameter - for (ho.outer: int32, 0, 4) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { + // filter cache read + for (ko.inner: int32, 0, 2) { for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[(((((ko.inner*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[((((((ko.outer*4096) + (ko.inner*2048)) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] } } } } - } - } - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - for (ko.c: int32, 0, 2) { - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } // end ko.inner + + // compute + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + + // init output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } } } - } - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + + // convolution + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = ( - (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * - (float32*)kernel_pointer[(((((ko.c*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * + (float32*)filter.cache[(((((ko.c.inner*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) ) - ) + } } } } } - } - } - } - } - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (ko: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + } // end wo.c + } // end ho.c.inner + } // end ko.c.inner + + // cache write + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + } } } } - } - } - } - } + } // end ho.inner + } // end ko.inner + } // end ho.outer + } // end ko.outer } ``` # 3x3 conv2d (no padding) -Change from a 1x1 kernel to a 3x3 kernel. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 kernel will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. +Change from a 1x1 filter to a 3x3 filter. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 filter will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. The key changes in TIR versus the above are... 1) Increased input cache size to hold the vertically adjacent slice ``` + // input cache grows to hold vertically adjacent slice allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; ``` @@ -298,19 +355,33 @@ The key changes in TIR versus the above are... The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. + +3) Increased filter cache size to hold 3x3 filter + +``` + // filter cache grows to hold larger 3x3 filter + allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; +``` + +4) Loops over `rh` and `rw` the kernel spatial dimensions: +``` + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { +``` + ## Command -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-3-1-0-float32-2-1-64-64-64-llvm]" +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-3-1-0-float32-2-2-1-64-64-128-llvm]" ## Parameters | Parameter | Value | | --------- | ----------- | | Batch | 1 | -| Kernel | 3x3 | +| Filter | 3x3 | | Spatial | 64x64 | | Input Ch | 64 | -| Output Ch | 64 | +| Output Ch | 128 | | Stride | 1 | | Padding | 0 | | Layout | NHWC8h8w32c | @@ -318,12 +389,10 @@ pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2d ## Assumptions -Same as above +* n/a ## To Do -Same as above, and ... - There may be some opportunity to optimize cache reuse in this case. Consider the loops over `ho.outer` and `ho.inner` and the index calculation `ho.outer * 64k + ho.inner * 32k` into the input pointer: | ho.outer | ho.inner | ho.outer * 64k + ho.inner * 32k | @@ -346,86 +415,103 @@ Noe that the vertically adjacent slice in loop N (i.e. the loop where `ho.outer` ## Annotated TIR ``` -primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), - kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 3, 3, 8, 32, 4], []), - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} - buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { - - // increased input cache size to hold vertically adjacent slice + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 3, 3, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + // input cache grows to hold vertically adjacent slice allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; + // filter cache grows to hold larger 3x3 filter + allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - for (ho.outer: int32, 0, 4) { - - // iterate over h_split + 1 = 3 input slices - for (ho.inner: int32, 0, 3) { - - // don't prefetch the vertically adjacent slice at the "bottom" of the input - if (((ho.outer*2) + ho.inner) < 8) { - for (wo: int32, 0, 8) { - for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + for (ko.outer: int32, 0, 2) { + for (ho.outer: int32, 0, 4) { + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (ho.inner: int32, 0, 3) { + if (((ho.outer*2) + ho.inner) < 8) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } } } } } } } - } - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - for (ko.c: int32, 0, 2) { - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + // filter cache read + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 2) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[(((((((ko.inner*18432) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[((((((((ko.outer*36864) + (ko.inner*18432)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] + } + } + } + } // end rw + } // end rh + } + } + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } } } - } - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = ( - (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * - (float32*)kernel_pointer[(((((((ko.c*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * + (float32*)filter.cache[(((((((ko.c.inner*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) ) - ) + } } - } - } + } // end rw + } // end rh } } } - } - } - } - } - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (ko: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + } // end wo.c + } // end ho.c.inner + } // end ko.c.inner + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + } } } } - } - } - } - } -} -``` \ No newline at end of file + } // end ho.inner + } // end ko.inner + } // end ho.outer + } // end ko.outer +}``` \ No newline at end of file diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py index 37a623b613f8..1304d341eda2 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py @@ -162,6 +162,7 @@ def conv2d_packed_filter( stride, padding, dtype, + k_split_factor, h_split_factor, storage_scope="global", ): @@ -263,6 +264,7 @@ def compute(n, ho, wo, ko, hi, wi, ki): # cache read for the input / activation (X) Xl = s.cache_read(X_packed, storage_scope, [Y]) + Fl = s.cache_read(filt_packed, storage_scope, [Y]) # cache write for the output (Y) Yl = s.cache_write(Y, storage_scope) @@ -277,7 +279,9 @@ def compute(n, ho, wo, ko, hi, wi, ki): # loop split h and compute cache write at outer loop split # to increase cache usage by factor of h_split_factor + koo, koi = s[Y].split(ko, factor=k_split_factor) hoo, hoi = s[Y].split(ho, factor=h_split_factor) + s[Y].reorder(n, koo, hoo, koi, hoi, wo, hi, wi, ki) s[Yl].compute_at(s[Y], hoo) #################### @@ -297,9 +301,11 @@ def compute(n, ho, wo, ko, hi, wi, ki): # loop split h and compute cache write at outer loop split # to increase cache usage by factor of h_split_factor + koo, koi = s[Yl].split(ko, factor=k_split_factor) hoo, hoi = s[Yl].split(ho, factor=h_split_factor) - s[Yl].reorder(n, hoo, hoi, wo, ko, rco, hi, wi, ki, rci) + s[Yl].reorder(n, koo, hoo, koi, hoi, wo, rco, hi, wi, ki, rci) s[Xl].compute_at(s[Yl], hoo) + s[Fl].compute_at(s[Yl], hoo) binds = {} if storage_scope and storage_scope != "global": @@ -318,6 +324,7 @@ def conv2d_packed_filter_nhwhwc( stride, padding, dtype, + k_split_factor, h_split_factor, storage_scope="global", ): @@ -406,6 +413,7 @@ def compute(n, ho, wo, hi, wi, k): # cache read for the input / activation (X) Xl = s.cache_read(X_packed, storage_scope, [Y]) + Fl = s.cache_read(filt_packed, storage_scope, [Y]) # cache write for the output (Y) Yl = s.cache_write(Y, storage_scope) @@ -423,8 +431,9 @@ def compute(n, ho, wo, hi, wi, k): # loop split h and compute cache write at outer loop split # to increase cache usage by factor of h_split_factor + koo, koi = s[Y].split(ko, factor=k_split_factor) hoo, hoi = s[Y].split(ho, factor=h_split_factor) - s[Y].reorder(n, hoo, hoi, wo, ko, hi, wi, ki) + s[Y].reorder(n, koo, hoo, koi, hoi, wo, hi, wi, ki) s[Yl].compute_at(s[Y], hoo) #################### @@ -445,9 +454,11 @@ def compute(n, ho, wo, hi, wi, k): # loop split h and compute cache write at outer loop split # to increase cache usage by factor of h_split_factor + koo, koi = s[Yl].split(ko, factor=k_split_factor) hoo, hoi = s[Yl].split(ho, factor=h_split_factor) - s[Yl].reorder(n, hoo, hoi, wo, ko, rco, hi, wi, ki, rci) + s[Yl].reorder(n, koo, hoo, koi, hoi, wo, rco, hi, wi, ki, rci) s[Xl].compute_at(s[Yl], hoo) + s[Fl].compute_at(s[Yl], hoo) ####################### # cache read schedule # @@ -474,12 +485,13 @@ def compute(n, ho, wo, hi, wi, k): class BaseConv2d: batch = tvm.testing.parameter(1) in_size = tvm.testing.parameter(8, 56, 64) - in_channel = tvm.testing.parameter(64) - out_channel = tvm.testing.parameter(64) + in_channel = tvm.testing.parameter(64, 128) + out_channel = tvm.testing.parameter(64, 128) kernel = tvm.testing.parameter(1, 3) stride = tvm.testing.parameter(1) pad = tvm.testing.parameter(0, 1) dtype = tvm.testing.parameter("float32") + k_split_factor = tvm.testing.parameter(1, 2) h_split_factor = tvm.testing.parameter(1, 2) @@ -504,7 +516,30 @@ def test_conv2d(self, shape_nhwc, shape_oihw, kernel, stride, pad, dtype, target padding=(pad, pad, pad, pad), dtype=dtype, ) - return output, ref_output + + # nhwc8h8w32c -> nhwc + output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( + output.shape[0], + output.shape[1] * output.shape[4], + output.shape[2] * output.shape[5], + output.shape[3] * output.shape[6], + ) + + # slice output to match ref_output shape + # e.g. 8x8 spatial 3x3 filter = 6x6 ref output + # but still 8x8 output given the blocked layout + output = output[ + 0 : ref_output.shape[0] : 1, + 0 : ref_output.shape[1] : 1, + 0 : ref_output.shape[2] : 1, + 0 : ref_output.shape[3] : 1, + ] + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + tvm.testing.assert_allclose(output, ref_output, **tol) class TestConv2dPackedFilter(BaseConv2d): @@ -522,6 +557,7 @@ def test_conv2d( pad, dtype, target, + k_split_factor, h_split_factor, ): inputs = [ @@ -543,9 +579,45 @@ def test_conv2d( stride=(stride, stride), padding=(pad, pad, pad, pad), dtype=dtype, + k_split_factor=k_split_factor, h_split_factor=h_split_factor, ) - return output, ref_output + + # nhwc8h8w32c + if len(output.shape) == 7: + # nhwc8h8w32c -> nhwc + output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( + output.shape[0], + output.shape[1] * output.shape[4], + output.shape[2] * output.shape[5], + output.shape[3] * output.shape[6], + ) + + # nhwhwc + else: + # nhwhwc -> nhwc + output = output.transpose(0, 1, 3, 2, 4, 5).reshape( + output.shape[0], + output.shape[1] * output.shape[3], + output.shape[2] * output.shape[4], + output.shape[5], + ) + + # slice output to match ref_output shape + # e.g. 8x8 spatial 3x3 filter = 6x6 ref output + # but still 8x8 output given the blocked layout + output = output[ + 0 : ref_output.shape[0] : 1, + 0 : ref_output.shape[1] : 1, + 0 : ref_output.shape[2] : 1, + 0 : ref_output.shape[3] : 1, + ] + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + tvm.testing.assert_allclose(output, ref_output, **tol) if __name__ == "__main__": From 3f064b617fb229b6152ca4c2ef5e1baa9064c397 Mon Sep 17 00:00:00 2001 From: ziyu-guo <40365354+ziyu-guo@users.noreply.github.com> Date: Wed, 20 Oct 2021 10:10:49 -0700 Subject: [PATCH 59/84] Add conv1d support in BYOC TRT by converting conv1d to conv2d (#9324) Co-authored-by: ziyu.guo --- python/tvm/relay/op/contrib/tensorrt.py | 19 ++++++++ src/runtime/contrib/tensorrt/tensorrt_ops.cc | 50 ++++++++++++++++++++ tests/python/contrib/test_tensorrt.py | 28 +++++++++++ 3 files changed, 97 insertions(+) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index cec7c4d141cb..03bb273c8f92 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -142,6 +142,7 @@ def partition_for_tensorrt( transform.RemoveUnusedFunctions(), transform.ConvertLayout( { + "nn.conv1d": ["NCW", "default"], "nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"], "nn.conv2d_transpose": ["NCHW", "default"], @@ -374,6 +375,23 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable return True +@_register_external_dynamic_check_func("nn.conv1d") +def conv1d_annotate_fn(expr): # pylint: disable=unused-variable + """Check if nn.conv1d is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.data_layout != "NCW": + logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIW": + logger.info("nn.conv1d: kernel_layout is %s but must be OIW.", attrs.kernel_layout) + return False + return True + + @_register_external_dynamic_check_func("nn.conv2d") def conv2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv2d is supported by TensorRT.""" @@ -912,6 +930,7 @@ def __init__(self): def visit_call(self, call): compute_intensive_ops = set( [ + "nn.conv1d", "nn.conv2d", "nn.conv2d_transpose", "nn.conv3d", diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 94bbae1559d9..a27fe1114af9 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -226,6 +226,55 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter { } }; +class Conv1DOpConverter : public TensorRTOpConverter { + public: + Conv1DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + auto weight_shape = params->inputs.at(1).weight_shape; + ICHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCW"); + ICHECK_EQ(params->node.GetAttr>("kernel_layout")[0], "OIW"); + auto str_strides = params->node.GetAttr>("strides"); + auto str_dilation = params->node.GetAttr>("dilation"); + auto str_padding = params->node.GetAttr>("padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + int channels = weight_shape[0]; + if (params->node.HasAttr("channels") && + !params->node.GetAttr>("channels")[0].empty()) { + channels = std::stoi(params->node.GetAttr>("channels")[0]); + } + + auto shuffle_layer = params->network->addShuffle(*input_tensor); + std::vector new_shape = {input_dims[0], input_dims[1], 1}; + shuffle_layer->setReshapeDimensions(VectorToTrtDims(new_shape)); + input_tensor = shuffle_layer->getOutput(0); + + const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], 1); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + + auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, + params->inputs.at(1).weight, bias); + ICHECK(conv_layer != nullptr); + conv_layer->setPadding(nvinfer1::DimsHW(std::stoi(str_padding[0]), 0)); + ICHECK_EQ(str_strides.size(), 1); + const auto strides = nvinfer1::DimsHW(std::stoi(str_strides[0]), 1); + conv_layer->setStride(strides); + ICHECK_EQ(str_dilation.size(), 1); + const auto dilation = nvinfer1::DimsHW(std::stoi(str_dilation[0]), 1); + conv_layer->setDilation(dilation); + conv_layer->setNbGroups(groups); + input_tensor = conv_layer->getOutput(0); + + auto conv_output_dims = TrtDimsToVector(input_tensor->getDimensions()); + std::vector back_shape = {0, 0}; + auto shuffle_back_layer = params->network->addShuffle(*input_tensor); + shuffle_back_layer->setReshapeDimensions(VectorToTrtDims(back_shape)); + params->outputs.push_back(shuffle_back_layer->getOutput(0)); + } +}; + class Conv2DOpConverter : public TensorRTOpConverter { public: Conv2DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} @@ -1198,6 +1247,7 @@ GetOpConverters() { map->emplace("nn.batch_norm", std::make_shared()); map->emplace("nn.layer_norm", std::make_shared()); map->emplace("nn.softmax", std::make_shared()); + map->emplace("nn.conv1d", std::make_shared()); map->emplace("nn.conv2d", std::make_shared()); map->emplace("nn.dense", std::make_shared()); map->emplace("nn.bias_add", std::make_shared()); diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index ec512d7d714f..df4234e7e605 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -355,6 +355,34 @@ def load_vm(): assert_result_dict_holds(result_dict) +def test_conv1d(run_module): + def get_graph( + x_shape=((1, 3, 224)), + k_shape=(10, 3, 3), + groups=1, + padding=(1, 1), + strides=(1), + dilation=(1), + channels=None, + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv1d( + x, + kernel, + kernel_size=k_shape[2:3], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + channels=channels, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph(channels=10), run_module=run_module) + + def test_conv2d(run_module): def get_graph( x_shape=(1, 32, 8, 8), From 9cf0245adb4cd61d07c278d7b7bcd01c9c5b9f8d Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 21 Oct 2021 07:19:36 +0900 Subject: [PATCH 60/84] [Relay, TOPI] Add searchsorted op (#9184) * Add relay definition * 1D cpu test working * multi dim working * gpu version working * check shape in type rel * support side * use target specfic max threads * add relay boilerplate * relay test working * cleanup topi test * fix test * add torch converter * handle other cases * more topi test * support torch bucketize * update doc * fix tests * fix lint * rebase fix * make the test case smaller * add tests for edge cases * replace "side" attribute with boolean "right" * add more descrition to binear_search IR gen params * return index from binary_search rather than update inplace * remove unused argument * format fix --- include/tvm/relay/attrs/algorithm.h | 16 +++ python/tvm/relay/frontend/pytorch.py | 22 +++ python/tvm/relay/op/_algorithm.py | 4 + python/tvm/relay/op/algorithm.py | 34 +++++ python/tvm/relay/op/op_attrs.py | 5 + python/tvm/relay/op/strategy/cuda.py | 12 ++ python/tvm/relay/op/strategy/generic.py | 25 ++++ python/tvm/topi/__init__.py | 1 + python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/searchsorted.py | 102 ++++++++++++++ python/tvm/topi/searchsorted.py | 127 ++++++++++++++++++ python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/searchsorted.py | 35 +++++ src/relay/op/algorithm/searchsorted.cc | 86 ++++++++++++ src/te/operation/create_primfunc.cc | 2 +- tests/python/frontend/pytorch/test_forward.py | 30 +++++ tests/python/relay/test_op_level5.py | 1 - tests/python/relay/test_op_level6.py | 24 ++++ .../topi/python/test_topi_searchsorted.py | 93 +++++++++++++ 19 files changed, 619 insertions(+), 2 deletions(-) create mode 100644 python/tvm/topi/cuda/searchsorted.py create mode 100644 python/tvm/topi/searchsorted.py create mode 100644 python/tvm/topi/testing/searchsorted.py create mode 100644 src/relay/op/algorithm/searchsorted.cc create mode 100644 tests/python/topi/python/test_topi_searchsorted.py diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 83b4ddaead43..3652a09e9168 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -76,6 +76,22 @@ struct TopKAttrs : public tvm::AttrsNode { } }; +struct SearchSortedAttrs : public tvm::AttrsNode { + bool right; + DataType dtype; + + TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") { + TVM_ATTR_FIELD(right).set_default(false).describe( + "Controls which index is returned if a value lands exactly on one of sorted values. If " + " false, the index of the first suitable location found is given. If true, return the " + "last such index. If there is no suitable index, return either 0 or N (where N is the " + "size of the innermost dimension)."); + TVM_ATTR_FIELD(dtype) + .set_default(DataType::Int(32)) + .describe("Data type of the output indices."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ALGORITHM_H_ diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 76cd0455661b..3fc202a7cc91 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2774,6 +2774,26 @@ def all_any_common(self, op, inputs, input_types): inp = inputs[0] return op(inp, axis=dim, keepdims=keepdim) + def searchsorted_common(self, sorted_sequence, values, out_int32, right): + dtype = "int32" if out_int32 else "int64" + values_shape = _infer_shape(values) + + if len(values_shape) == 0: + values = _op.expand_dims(values, 0) + + out = _op.searchsorted(sorted_sequence, values, right=right, dtype=dtype) + + if len(values_shape) == 0: + return _op.squeeze(out) + + return out + + def searchsorted(self, inputs, input_types): + return self.searchsorted_common(*inputs) + + def bucketize(self, inputs, input_types): + return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3]) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2999,6 +3019,8 @@ def create_convert_map(self): "aten::lstm": self.lstm, "aten::all": functools.partial(self.all_any_common, _op.all), "aten::any": functools.partial(self.all_any_common, _op.any), + "aten::searchsorted": self.searchsorted, + "aten::bucketize": self.bucketize, } def update_convert_map(self, custom_map): diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 817f96b696df..19162a108395 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -41,6 +41,10 @@ register_strategy("topk", strategy.topk_strategy) register_pattern("topk", OpPattern.OPAQUE) +# searchsorted +register_strategy("searchsorted", strategy.searchsorted_strategy) +register_pattern("searchsorted", OpPattern.OPAQUE) + @script def _topk_shape_func_input_shape(data_shape, k, axis): diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 119936f632f8..809a9061ade0 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -115,3 +115,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): if ret_type == "both": return TupleWrapper(out, 2) return out + + +def searchsorted(sorted_sequence, values, right=False, dtype="int32"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : relay.Expr + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : relay.Expr + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : relay.Expr + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ + return _make.searchsorted(sorted_sequence, values, right, dtype) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 8fd46817b817..dba40b2f6f34 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -564,6 +564,11 @@ class TopkAttrs(Attrs): """Attributes used in topk operators""" +@tvm._ffi.register_object("relay.attrs.SearchSortedAttrs") +class SearchSortedAttrs(Attrs): + """Attributes used in searchsorted operators""" + + @tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs") class TupleGetItemAttrs(Attrs): """Attributes used in tuple item access operators""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index da7cbd5cec10..5f24dbda9d35 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1022,6 +1022,18 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): return strategy +@searchsorted_strategy.register(["cuda", "gpu"]) +def searchsorted_strategy_cuda(attrs, inputs, out_type, target): + """searchsorted cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_searchsorted(topi.cuda.searchsorted), + wrap_topi_schedule(topi.cuda.schedule_extern), + name="searchsorted.cuda", + ) + return strategy + + @multibox_prior_strategy.register(["cuda", "gpu"]) def multibox_prior_strategy_cuda(attrs, inputs, out_type, target): """multibox_prior cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index d021b5d9d84d..777f17ba6084 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1002,6 +1002,31 @@ def topk_strategy(attrs, inputs, out_type, target): return strategy +# searchsorted +def wrap_compute_searchsorted(topi_compute): + """Wrap searchsorted compute""" + + def _compute_searchsorted(attrs, inputs, out_type): + right = attrs.right + dtype = attrs.dtype + return [topi_compute(inputs[0], inputs[1], right, dtype)] + + return _compute_searchsorted + + +# searchsorted_strategy +@override_native_generic_func("searchsorted_strategy") +def searchsorted_strategy(attrs, inputs, out_type, target): + """searchsorted generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_searchsorted(topi.searchsorted), + wrap_topi_schedule(topi.generic.schedule_extern), + name="searchsorted.generic", + ) + return strategy + + # multibox_prior def wrap_compute_multibox_prior(topi_compute): """Wrap multibox_prior compute""" diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 6b22cf13f5b9..e243d6ee3bc7 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -45,6 +45,7 @@ from .scan import * from .einsum import * from .unique import * +from .searchsorted import * from . import generic from . import nn from . import x86 diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 21ddf57ca1d0..88d306761310 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -59,3 +59,4 @@ from .sparse_reshape import * from .transform import * from .unique import * +from .searchsorted import * diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py new file mode 100644 index 000000000000..1c39ccaa8632 --- /dev/null +++ b/python/tvm/topi/cuda/searchsorted.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""searchsorted operator for GPU""" +import tvm +from tvm import te +from .. import utils +from ..searchsorted import binary_search + + +def searchsorted(sorted_sequence, values, right, out_dtype="int64"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ + + def ir(sorted_sequence, values, indices): + ib = tvm.tir.ir_builder.create() + sorted_sequence_shape = sorted_sequence.shape + values_shape = values.shape + num_search = utils.prod(values_shape) + search_range = sorted_sequence_shape[-1] + + sorted_sequence = ib.buffer_ptr(sorted_sequence) + values = ib.buffer_ptr(values) + indices = ib.buffer_ptr(indices) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr( + bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads) + ) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < num_search): + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = tid // values_shape[-1] + sequence_offset = sequence_id * search_range + + indices[tid] = binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence, + values[tid], + right, + out_dtype, + ) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted", + dtype=out_dtype, + ) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py new file mode 100644 index 000000000000..28ffd170c955 --- /dev/null +++ b/python/tvm/topi/searchsorted.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""searchsorted operator""" +from . import utils +from . import te +from ..tir import ir_builder +from .math import cast + + +def binary_search(ib, sequence_offset, search_range, sorted_sequence, value, right, out_dtype): + """Common IR generator for binary search used by CPU and GPU backends. + + `sorted_sequence` is a N-D Buffer whose innermost dimension we want to search for `value`, + and `search_range` is the size of the innermost dimension. `sequence_offset` is + a 1-D linearlized offset specifying which of innermost sequences to search. + + So the search for `value` is performed over + `sorted_sequence[sequence_offset:(sequence_offset + search_range)]`. + Note that we index N-D Buffer by 1-D linearlized indices. + + """ + lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") + hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") + + lo[0] = cast(0, out_dtype) + hi[0] = cast(search_range, out_dtype) + + # Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu + def condition(current_val, target_val): + if right: + return current_val <= target_val + return current_val < target_val + + with ib.while_loop(lo[0] < hi[0]): + mid = lo[0] + (hi[0] - lo[0] >> 1) + with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], value)): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + return lo[0] + + +def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ + + def ir(sorted_sequence, values, indices): + ib = ir_builder.create() + sorted_sequence_shape = sorted_sequence.shape + values_shape = values.shape + num_search = utils.prod(values_shape) + search_range = sorted_sequence_shape[-1] + + sorted_sequence = ib.buffer_ptr(sorted_sequence) + values = ib.buffer_ptr(values) + indices = ib.buffer_ptr(indices) + + with ib.for_range(0, num_search, name="i", kind="parallel") as i: + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = i // values_shape[-1] + sequence_offset = sequence_id * search_range + + indices[i] = binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence, + values[i], + right, + out_dtype, + ) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted", + dtype=out_dtype, + ) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index d10c49f5c084..2d7d0a4b9e11 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -73,3 +73,4 @@ from .batch_to_space_nd import batch_to_space_nd_python from .nll_loss import nll_loss from .dense import dense +from .searchsorted import searchsorted_ref diff --git a/python/tvm/topi/testing/searchsorted.py b/python/tvm/topi/testing/searchsorted.py new file mode 100644 index 000000000000..10762600992d --- /dev/null +++ b/python/tvm/topi/testing/searchsorted.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The reference implementation of searchsorted in Numpy.""" +import numpy as np + + +def searchsorted_ref(sorted_sequence, values, right, out_dtype): + """Run Numpy searchsorted on 1-D or N-D sorted_sequence.""" + side = "right" if right else "left" + if len(sorted_sequence.shape) == 1 and len(values.shape) > 1: + sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1)) + else: + sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1])) + + values_2d = np.reshape(values, (-1, values.shape[-1])) + indices = np.zeros(values_2d.shape, dtype=out_dtype) + + for i in range(indices.shape[0]): + indices[i] = np.searchsorted(sorted_sequence_2d[i], values_2d[i], side=side) + + return np.reshape(indices, values.shape) diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc new file mode 100644 index 000000000000..be5921311660 --- /dev/null +++ b/src/relay/op/algorithm/searchsorted.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file searchsorted.cc + * \brief SearchSorted operators + */ +#include +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(SearchSortedAttrs); + +bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const SearchSortedAttrs* param = attrs.as(); + ICHECK_EQ(types.size(), 3); + const auto* sorted_sequence = types[0].as(); + const auto* values = types[1].as(); + ICHECK(sorted_sequence) << "Expects TensorType in the first input"; + ICHECK(values) << "Expects TensorType in the second input"; + ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one"; + + if (sorted_sequence->shape.size() > 1) { + ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size()) + << "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is " + "multi-dimensional."; + + for (size_t i = 0; i < values->shape.size() - 1; ++i) { + if (sorted_sequence->shape[i].as() && values->shape[i].as()) { + ICHECK_EQ(sorted_sequence->shape[i].as()->value, + values->shape[i].as()->value) + << "`sorted_sequence and `values` do not have the same shape along outer axes"; + } + } + } + + reporter->Assign(types[2], TensorType(values->shape, param->dtype)); + return true; +} + +Expr MakeSearchSorted(Expr sorted_sequence, Expr values, Bool right, DataType dtype) { + auto attrs = make_object(); + static const Op& op = Op::Get("searchsorted"); + attrs->dtype = dtype; + attrs->right = right; + return Call(op, {sorted_sequence, values}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.searchsorted").set_body_typed(MakeSearchSorted); + +RELAY_REGISTER_OP("searchsorted") + .describe( + R"doc(Find indices where elements should be inserted to maintain order. +If `sorted_sequence` is N-dimensional, the innermost dimension of +`values` are searched in the corresponding dimension of `sorted_sequence`. +)doc" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("sorted_sequence", "Tensor", + "Monotonically increasing sequence on the innermost dimension.") + .add_argument("values", "Tensor", "Values to search for.") + .set_support_level(6) + .add_type_rel("SearchSorted", SearchSortedRel); + +} // namespace relay +} // namespace tvm diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index aa164b03a2a7..657dc121961c 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -48,7 +48,7 @@ class ProducerToBufferTransformer : public StmtExprMutator { const std::unordered_map& tensor2buffers_; }; -/*! \brief Helper data structural to store informations. */ +/*! \brief Helper data structure to store information. */ struct CreateFuncInfo { /*! \brief The Tensor arg_list. */ Array arg_list; diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3a3889d5cfb7..0031f4143fab 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3962,5 +3962,35 @@ def test_fn(f, dim=None, keepdim=False): verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()]) +@tvm.testing.uses_gpu +def test_searchsorted(): + def test_fn(out_int32=False, right=False): + return lambda x, y: torch.searchsorted(x, y, out_int32=out_int32, right=right) + + sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + verify_model(test_fn(), [sorted_sequence, values]) + verify_model(test_fn(out_int32=True), [sorted_sequence[0], values[0]]) + verify_model(test_fn(right=True), [sorted_sequence, values]) + + sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + values = torch.tensor([[3, 6, 9], [4, 2, 7]]) + verify_model(test_fn(), [sorted_sequence_1d, values]) + + verify_model(test_fn(), [sorted_sequence_1d, torch.tensor(6)]) + + +@tvm.testing.uses_gpu +def test_bucketize(): + def test_fn(out_int32=False, right=False): + return lambda x, y: torch.bucketize(x, y, out_int32=out_int32, right=right) + + boundaries = torch.tensor([1, 3, 5, 7, 9]) + values = torch.tensor([3, 6, 9]) + + verify_model(test_fn(), [values, boundaries]) + verify_model(test_fn(out_int32=True, right=True), [values, boundaries]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index eb4eee379b08..c968c5a7f19f 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -773,7 +773,6 @@ def verify_roi_align( mode=mode, ) for target, dev in tvm.testing.enabled_targets(): - print("test on", target) op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( np_data, np_rois ) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index ea640c62dfeb..48c58dc2dc33 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -20,6 +20,7 @@ import numpy as np import tvm from tvm import relay +from tvm.topi.testing import searchsorted_ref import tvm.testing @@ -149,5 +150,28 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): verify_topk(k, axis, ret_type, False, "int64", "float16") +@tvm.testing.uses_gpu +def test_searchsorted(): + def verify_searchsorted(right, dtype): + shape = (8, 9, 10) + values_shape = shape[:-1] + (10,) + sorted_sequence = relay.var("sorted_sequence", relay.TensorType(shape, "float32")) + values = relay.var("sorted_sequence", relay.TensorType(values_shape, "float32")) + out = relay.searchsorted(sorted_sequence, values, right, dtype) + func = relay.Function([sorted_sequence, values], out) + sorted_sequence_np = np.sort(np.random.randn(*shape).astype("float32"), axis=-1) + values_np = np.random.randn(*values_shape).astype("float32") + np_indices = searchsorted_ref(sorted_sequence_np, values_np, right, dtype) + + for target, dev in tvm.testing.enabled_targets(): + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + sorted_sequence_np, values_np + ) + np.testing.assert_equal(op_res.numpy(), np_indices) + + verify_searchsorted(False, "int32") + verify_searchsorted(True, "int64") + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py new file mode 100644 index 000000000000..7b3976b7eb74 --- /dev/null +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing +from tvm.topi.testing import searchsorted_ref +from tvm import te, topi + +topi_funcs = {"generic": topi.searchsorted, "cuda": topi.cuda.searchsorted} + + +def get_implementations(): + topi_func_generic = topi_funcs["generic"] + topi_func_cuda = topi_funcs["cuda"] + + return { + "generic": ( + lambda x, y, side, out_dtype: topi_func_generic(x, y, side, out_dtype), + topi.generic.schedule_extern, + ), + "cuda": ( + lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype), + topi.cuda.schedule_extern, + ), + "vulkan": ( + lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype), + topi.cuda.schedule_extern, + ), + } + + +@tvm.testing.parametrize_targets +def test_searchsorted(dev, target): + def verify_with_input(sorted_sequence_np, values_np, right): + sorted_sequence = te.placeholder(sorted_sequence_np.shape, dtype="float32") + values = te.placeholder(values_np.shape, dtype="float32") + out_dtype = "int32" + implementations = get_implementations() + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + + with tvm.target.Target(target): + indices = fcompute(sorted_sequence, values, right, out_dtype) + s = fschedule([indices]) + + func = tvm.build(s, [sorted_sequence, values, indices], target=target) + dev = tvm.device(target, 0) + + a = tvm.nd.array(sorted_sequence_np, dev) + b = tvm.nd.array(values_np, dev) + c = tvm.nd.array(np.zeros(values_np.shape, dtype=indices.dtype), dev) + func(a, b, c) + ref = searchsorted_ref(sorted_sequence_np, values_np, right, out_dtype) + np.testing.assert_equal(c.numpy(), ref) + + def verify(sequence_len, num_search, outer_axes, right, sorted_sequence_1d=False): + if sorted_sequence_1d: + sorted_sequence_shape = (sequence_len,) + else: + sorted_sequence_shape = outer_axes + (sequence_len,) + values_shape = outer_axes + (num_search,) + + verify_with_input( + np.sort(np.random.randn(*sorted_sequence_shape).astype("float32"), axis=-1), + np.random.randn(*values_shape).astype("float32"), + right, + ) + + verify(1024, 1000, (10, 5, 3), False) + verify(999, 2000, (10, 5, 3), True) + verify(1000, 1000, (), False) + verify(2001, 100, (500,), True) + verify(2001, 100, (500,), False, sorted_sequence_1d=True) + + # Check edge cases + for right in [True, False]: + sorted_sequence = np.array([1, 2, 3, 4, 5], dtype="float32") + verify_with_input(sorted_sequence, np.array([6], dtype="float32"), right) + verify_with_input(sorted_sequence, np.array([0], dtype="float32"), right) From 3a5a09d19d3e3233eb78a512ca7f7485884038d2 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 20 Oct 2021 18:39:09 -0700 Subject: [PATCH 61/84] [Error reporting] Replace runtime errors with LOG(FATAL) (#9311) * Replace runtime error with LOG(FATAL) * flaky test --- src/ir/module.cc | 2 +- src/relay/backend/aot_executor_codegen.cc | 18 ++++++++++-------- src/relay/backend/graph_executor_codegen.cc | 16 ++++++++-------- src/relay/transforms/defunctionalization.cc | 4 ++-- .../contrib/arm_compute_lib/acl_runtime.cc | 2 +- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/ir/module.cc b/src/ir/module.cc index 15c441d61a23..3deb70dd766c 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -170,7 +170,7 @@ Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) } LOG(FATAL) << adt << " does not contain constructor " << cons; - throw std::runtime_error("Constructor Not Found."); + return {}; } tvm::Array IRModuleNode::GetGlobalTypeVars() const { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 38eb6aa6a07e..56e008a345de 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -440,31 +440,33 @@ class AOTExecutorCodegen : public MixedModeVisitor { void VisitExpr_(const LetNode* op) override { // TODO(giuseros): support Let nodes in AOT - CHECK(false) << "Let not yet implemented in AOT"; + LOG(FATAL) << "Let not yet implemented in AOT"; } void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } void VisitExpr_(const OpNode* op) override { - throw std::runtime_error("can not compile op in non-eta expanded form"); + LOG(FATAL) << "All OpNodes should have been expanded"; + } + void VisitExpr_(const IfNode* op) override { + LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called"; } - void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } void VisitExpr_(const FunctionNode* op) override { ICHECK(op->GetAttr(attr::kCompiler).defined()) << "FunctionNode only supported by custom codegen"; } void VisitExpr_(const RefCreateNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)"; } void VisitExpr_(const RefReadNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "AOT executor does not support references (found RefReadNode)"; } void VisitExpr_(const RefWriteNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)"; } void VisitExpr_(const ConstructorNode* op) override { - throw std::invalid_argument("ADT constructor case not yet implemented"); + LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)"; } void VisitExpr_(const MatchNode* op) override { - throw std::invalid_argument("match case not yet implemented"); + LOG(FATAL) << "AOT executor does not support matching (found MatchNode)"; } // Create the main PrimFunc to execute the graph. Please note that diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index dbe14b63293f..debd669126c4 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -473,15 +473,15 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorindex]}; } std::vector VisitExpr_(const OpNode* op) override { - throw std::runtime_error("can not compile op in non-eta expanded form"); + LOG(FATAL) << "All OpNodes should have been expanded"; return {}; } std::vector VisitExpr_(const GlobalVarNode* op) override { - throw std::runtime_error(""); + LOG(FATAL) << "All GlobalVarNodes should be removed before graph executor's Codegen is called"; return {}; } std::vector VisitExpr_(const IfNode* op) override { - throw std::invalid_argument("if not supported"); + LOG(FATAL) << "Graph executor does not support control flow (found IfNode)"; return {}; } std::vector VisitExpr_(const FunctionNode* op) override { @@ -490,23 +490,23 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const RefCreateNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "Graph executor does not support references (found RefCreateNode)"; return {}; } std::vector VisitExpr_(const RefReadNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "Graph executor does not support references (found RefReadNode)"; return {}; } std::vector VisitExpr_(const RefWriteNode* op) override { - throw std::invalid_argument("reference not supported"); + LOG(FATAL) << "Graph executor does not support references (found RefWriteNode)"; return {}; } std::vector VisitExpr_(const ConstructorNode* op) override { - throw std::invalid_argument("ADT constructor case not yet implemented"); + LOG(FATAL) << "Graph executor does not support ADTs (found ConstructorNode)"; return {}; } std::vector VisitExpr_(const MatchNode* op) override { - throw std::invalid_argument("match case not yet implemented"); + LOG(FATAL) << "Graph executor does not support matching (found MatchNode)"; return {}; } /*! diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 14a86bc8d080..5255a672a856 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -288,8 +288,8 @@ class DefuncMutator : public ExprMutator { return Call(c, call_args); } - - throw std::runtime_error("EncodeArg failed to cast arg into identifier node or function node"); + LOG(FATAL) << "EncodeArg failed to cast arg into identifier node or function node"; + return {}; } /*! diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 5bbc536afaca..a336cf494f4b 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -499,7 +499,7 @@ class ACLRuntime : public JSONRuntimeBase { layer->outputs.push_back( MakeACLTensorFromJSONNode(node, &node.GetInputs()[6], &node.GetInputs()[7])); } else { - throw std::runtime_error("Unsupported form of add op: " + op_name); + LOG(FATAL) << "Unsupported form of add op: " + op_name; } auto f = std::make_shared(); From 4f24921161732a51da6bf6fb08762fe2ec37114d Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 21 Oct 2021 11:13:47 +0900 Subject: [PATCH 62/84] Use variable in curl download url (#9330) * Use variable for curl download url * Replace qemu-5.1.0.tar.xz.sig with ${QEMU_SIG_FILE} --- docker/install/ubuntu_install_qemu.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install/ubuntu_install_qemu.sh b/docker/install/ubuntu_install_qemu.sh index 1189f2bb8dd4..6682795b0fd8 100755 --- a/docker/install/ubuntu_install_qemu.sh +++ b/docker/install/ubuntu_install_qemu.sh @@ -54,7 +54,7 @@ apt update apt-get -y build-dep qemu gpg --keyserver keyserver.ubuntu.com --recv-keys 0x3353C9CEF108B584 -cat <qemu-5.1.0.tar.xz.sig +cat <${QEMU_SIG_FILE} -----BEGIN PGP ARMORED FILE----- Comment: Use "gpg --dearmor" for unpacking @@ -68,7 +68,7 @@ p5ez/+2k4VAIwIQoP5DoO06waLBffvLIAdPPKYsx71K67OoGG2svc7duC/+5qf1x =hCS7 -----END PGP ARMORED FILE----- EOF -curl -OLs https://download.qemu.org/qemu-5.1.0.tar.xz +curl -OLs https://download.qemu.org/${QEMU_TAR_FILE} gpg --verify ${QEMU_SIG_FILE} tar -xf ${QEMU_TAR_FILE} From f4c146ca37c061a1192fd3aaa988ae23ed1bed67 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 20 Oct 2021 21:42:15 -0700 Subject: [PATCH 63/84] [Relay] Remove FTVMCompute from TNonComputational ops (#9334) * remove FTVMCompute from noncomputational ops * Remove injective schedule registration for on_device since it is non-computational * lint --- python/tvm/relay/op/_tensor.py | 3 --- src/relay/op/annotation/annotation.cc | 7 +------ src/relay/op/memory/memory.cc | 21 +++------------------ src/relay/op/vm/vm.cc | 14 ++------------ 4 files changed, 6 insertions(+), 39 deletions(-) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 18ce93322f43..daec488bbb94 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -89,9 +89,6 @@ register_broadcast_schedule("fast_exp") register_broadcast_schedule("fast_tanh") register_broadcast_schedule("fast_erf") -# a fake on_device schedule. -# this will not be used in actual computation -register_injective_schedule("on_device") # zeros diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index beadf4a67ddc..8b00839cda33 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -94,12 +94,7 @@ RELAY_REGISTER_OP("on_device") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("TNonComputational", true) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_type) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("TNonComputational", true); OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { if (call_node->op == OnDeviceOp()) { diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 5339d48e3a2f..6b22cfd6bdba 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -91,12 +91,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, Array assert_shape) { @@ -206,12 +201,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); bool KillRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -230,12 +220,7 @@ RELAY_REGISTER_OP("memory.kill") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); static void FlattenTupleTypeAux(const Type& type, std::vector* out) { if (auto tt = type.as()) { diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc index be31b5482937..65a4ec01805b 100644 --- a/src/relay/op/vm/vm.cc +++ b/src/relay/op/vm/vm.cc @@ -138,12 +138,7 @@ RELAY_REGISTER_OP("vm.shape_func") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // vm.invoke_tvm_op bool InvokeTVMOpRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -188,12 +183,7 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // vm.reshape TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs); From 88bf112454bbd30058719260e0fca9dd49c8692e Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 21 Oct 2021 04:09:37 -0700 Subject: [PATCH 64/84] Specify argument to FastMathFlags setAllowContract (#9337) --- src/target/llvm/llvm_module.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 657778df0e93..86079b25aa90 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -292,7 +292,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { Bool fast_math_afn = target->GetAttr("fast-math-afn").value_or(Bool(false)); Bool fast_math_reassoc = target->GetAttr("fast-math-reassoc").value_or(Bool(false)); if (fast_math_contract) { - fmf.setAllowContract(); + fmf.setAllowContract(true); } if (fast_math_afn) { fmf.setApproxFunc(); From e62075df1f6a2926aebf9b3655ba1284a2c1d8d2 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 21 Oct 2021 10:15:43 -0700 Subject: [PATCH 65/84] [microTVM][Arduino] Cleanup template directory (#9289) * restructure * readme * fix readme * trigger --- apps/microtvm/arduino/README.md | 18 ++++ .../standalone_crt/crt_config/crt_config.h | 55 ----------- .../arduino/template_project/boards.json | 59 +++++++++++ .../crt_config/crt_config.h | 0 .../template_project/microtvm_api_server.py | 98 +++++-------------- .../src/example_project}/model.c | 0 .../src/example_project}/model.h | 0 .../src}/example_project/project.ino | 0 .../src/host_driven}/model_support.c | 0 .../src}/host_driven/project.ino | 0 apps/microtvm/zephyr/README.md | 2 +- tests/micro/arduino/conftest.py | 20 ++-- 12 files changed, 110 insertions(+), 142 deletions(-) create mode 100644 apps/microtvm/arduino/README.md delete mode 100644 apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h create mode 100644 apps/microtvm/arduino/template_project/boards.json rename apps/microtvm/arduino/{example_project/src/standalone_crt => template_project}/crt_config/crt_config.h (100%) rename apps/microtvm/arduino/{example_project/src => template_project/src/example_project}/model.c (100%) rename apps/microtvm/arduino/{example_project/src => template_project/src/example_project}/model.h (100%) rename apps/microtvm/arduino/{ => template_project/src}/example_project/project.ino (100%) rename apps/microtvm/arduino/{host_driven/src => template_project/src/host_driven}/model_support.c (100%) rename apps/microtvm/arduino/{ => template_project/src}/host_driven/project.ino (100%) diff --git a/apps/microtvm/arduino/README.md b/apps/microtvm/arduino/README.md new file mode 100644 index 000000000000..b33557b53239 --- /dev/null +++ b/apps/microtvm/arduino/README.md @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + +This directory contains code to interface microTVM with [Arduino](https://www.arduino.cc/). diff --git a/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h b/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h deleted file mode 100644 index cf73103aff8b..000000000000 --- a/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief CRT configuration for the host-linked CRT. - */ -#ifndef TVM_RUNTIME_MICRO_CRT_CONFIG_H_ -#define TVM_RUNTIME_MICRO_CRT_CONFIG_H_ - -/*! Log level of the CRT runtime */ -#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG - -/*! Support low-level debugging in MISRA-C runtime */ -#define TVM_CRT_DEBUG 0 - -/*! Maximum supported dimension in NDArray */ -#define TVM_CRT_MAX_NDIM 6 -/*! Maximum supported arguments in generated functions */ -#define TVM_CRT_MAX_ARGS 10 -/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ -#define TVM_CRT_MAX_STRLEN_DLTYPE 10 -/*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 - -/*! Maximum number of registered modules. */ -#define TVM_CRT_MAX_REGISTERED_MODULES 2 - -/*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512 - -/*! Maximum packet size, in bytes, including the length header. */ -#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8 * 1024 - -/*! \brief Maximum length of a PackedFunc function name. */ -#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 - -// #define TVM_CRT_FRAMER_ENABLE_LOGS - -#endif // TVM_RUNTIME_MICRO_CRT_CONFIG_H_ diff --git a/apps/microtvm/arduino/template_project/boards.json b/apps/microtvm/arduino/template_project/boards.json new file mode 100644 index 000000000000..595d56b5f615 --- /dev/null +++ b/apps/microtvm/arduino/template_project/boards.json @@ -0,0 +1,59 @@ +{ + "due": { + "package": "arduino", + "architecture": "sam", + "board": "arduino_due_x_dbg", + "model": "sam3x8e" + }, + "feathers2": { + "package": "esp32", + "architecture": "esp32", + "board": "feathers2", + "model": "esp32", + "note": "Due to the way the Feather S2 bootloader works, compilation behaves fine but uploads cannot be done automatically." + }, + "metrom4": { + "package": "adafruit", + "architecture": "samd", + "board": "adafruit_metro_m4", + "model": "atsamd51" + }, + "spresense": { + "package": "SPRESENSE", + "architecture": "spresense", + "board": "spresense", + "model": "cxd5602gg", + "note": "Spresense only works as of its v2.3.0 sdk." + }, + "nano33ble": { + "package": "arduino", + "architecture": "mbed_nano", + "board": "nano33ble", + "model": "nrf52840" + }, + "pybadge": { + "package": "adafruit", + "architecture": "samd", + "board": "adafruit_pybadge_m4", + "model": "atsamd51" + }, + "teensy40": { + "package": "teensy", + "architecture": "avr", + "board": "teensy40", + "model": "imxrt1060", + "note": "The Teensy boards are listed here for completeness, but they won't work until https://github.com/arduino/arduino-cli/issues/700 is finished." + }, + "teensy41": { + "package": "teensy", + "architecture": "avr", + "board": "teensy41", + "model": "imxrt1060" + }, + "wioterminal": { + "package": "Seeeduino", + "architecture": "samd", + "board": "seeed_wio_terminal", + "model": "atsamd51" + } +} diff --git a/apps/microtvm/arduino/example_project/src/standalone_crt/crt_config/crt_config.h b/apps/microtvm/arduino/template_project/crt_config/crt_config.h similarity index 100% rename from apps/microtvm/arduino/example_project/src/standalone_crt/crt_config/crt_config.h rename to apps/microtvm/arduino/template_project/crt_config/crt_config.h diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 3d25d0bcad8f..e285ecc6e3b0 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -44,77 +44,21 @@ IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() +BOARDS = API_SERVER_DIR / "boards.json" + +# Data structure to hold the information microtvm_api_server.py needs +# to communicate with each of these boards. +try: + with open(BOARDS) as boards: + BOARD_PROPERTIES = json.load(boards) +except FileNotFoundError: + raise FileNotFoundError(f"Board file {{{BOARDS}}} does not exist.") + + class BoardAutodetectFailed(Exception): """Raised when no attached hardware is found matching the requested board""" -# Data structure to hold the information microtvm_api_server.py needs -# to communicate with each of these boards. Currently just holds the -# components of each board's FQBN, but might be extended in the future -# to include the SRAM, PSRAM, flash, etc. on each board. -BOARD_PROPERTIES = { - "due": { - "package": "arduino", - "architecture": "sam", - "board": "arduino_due_x_dbg", - "model": "sam3x8e", - }, - # Due to the way the Feather S2 bootloader works, compilation - # behaves fine but uploads cannot be done automatically - "feathers2": { - "package": "esp32", - "architecture": "esp32", - "board": "feathers2", - "model": "esp32", - }, - "metrom4": { - "package": "adafruit", - "architecture": "samd", - "board": "adafruit_metro_m4", - "model": "atsamd51", - }, - # Spresense only works as of its v2.3.0 sdk - "spresense": { - "package": "SPRESENSE", - "architecture": "spresense", - "board": "spresense", - "model": "cxd5602gg", - }, - "nano33ble": { - "package": "arduino", - "architecture": "mbed_nano", - "board": "nano33ble", - "model": "nrf52840", - }, - "pybadge": { - "package": "adafruit", - "architecture": "samd", - "board": "adafruit_pybadge_m4", - "model": "atsamd51", - }, - # The Teensy boards are listed here for completeness, but they - # won't work until https://github.com/arduino/arduino-cli/issues/700 - # is finished - "teensy40": { - "package": "teensy", - "architecture": "avr", - "board": "teensy40", - "model": "imxrt1060", - }, - "teensy41": { - "package": "teensy", - "architecture": "avr", - "board": "teensy41", - "model": "imxrt1060", - }, - "wioterminal": { - "package": "Seeeduino", - "architecture": "samd", - "board": "seeed_wio_terminal", - "model": "atsamd51", - }, -} - PROJECT_TYPES = ["example_project", "host_driven"] PROJECT_OPTIONS = [ @@ -123,11 +67,6 @@ class BoardAutodetectFailed(Exception): choices=list(BOARD_PROPERTIES), help="Name of the Arduino board to build for", ), - server.ProjectOption( - "arduino_model", - choices=[board["model"] for _, board in BOARD_PROPERTIES.items()], - help="Name of the model for each Arduino board.", - ), server.ProjectOption("arduino_cli_cmd", help="Path to the arduino-cli tool."), server.ProjectOption("port", help="Port to use for connecting to hardware"), server.ProjectOption( @@ -166,8 +105,9 @@ def _copy_project_files(self, api_server_dir, project_dir, project_type): so this file is copied separately in generate_project. """ - project_types_folder = api_server_dir.parents[0] - for item in (project_types_folder / project_type / "src").iterdir(): + for item in (API_SERVER_DIR / "src" / project_type).iterdir(): + if item.name == "project.ino": + continue dest = project_dir / "src" / item.name if item.is_dir(): shutil.copytree(item, dest) @@ -176,7 +116,7 @@ def _copy_project_files(self, api_server_dir, project_dir, project_type): # Arduino requires the .ino file have the same filename as its containing folder shutil.copy2( - project_types_folder / project_type / "project.ino", + API_SERVER_DIR / "src" / project_type / "project.ino", project_dir / f"{project_dir.stem}.ino", ) @@ -344,12 +284,20 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec # Copies files from the template folder to project_dir shutil.copy2(API_SERVER_DIR / "microtvm_api_server.py", project_dir) + shutil.copy2(BOARDS, project_dir / BOARDS.name) self._copy_project_files(API_SERVER_DIR, project_dir, options["project_type"]) # Copy standalone_crt into src folder self._copy_standalone_crt(source_dir, standalone_crt_dir) self._remove_unused_components(source_dir, options["project_type"]) + # Populate crt-config.h + crt_config_dir = project_dir / "src" / "standalone_crt" / "crt_config" + crt_config_dir.mkdir() + shutil.copy2( + API_SERVER_DIR / "crt_config" / "crt_config.h", crt_config_dir / "crt_config.h" + ) + # Unpack the MLF and copy the relevant files metadata = self._disassemble_mlf(model_library_format_path, source_dir) shutil.copy2(model_library_format_path, source_dir / "model") diff --git a/apps/microtvm/arduino/example_project/src/model.c b/apps/microtvm/arduino/template_project/src/example_project/model.c similarity index 100% rename from apps/microtvm/arduino/example_project/src/model.c rename to apps/microtvm/arduino/template_project/src/example_project/model.c diff --git a/apps/microtvm/arduino/example_project/src/model.h b/apps/microtvm/arduino/template_project/src/example_project/model.h similarity index 100% rename from apps/microtvm/arduino/example_project/src/model.h rename to apps/microtvm/arduino/template_project/src/example_project/model.h diff --git a/apps/microtvm/arduino/example_project/project.ino b/apps/microtvm/arduino/template_project/src/example_project/project.ino similarity index 100% rename from apps/microtvm/arduino/example_project/project.ino rename to apps/microtvm/arduino/template_project/src/example_project/project.ino diff --git a/apps/microtvm/arduino/host_driven/src/model_support.c b/apps/microtvm/arduino/template_project/src/host_driven/model_support.c similarity index 100% rename from apps/microtvm/arduino/host_driven/src/model_support.c rename to apps/microtvm/arduino/template_project/src/host_driven/model_support.c diff --git a/apps/microtvm/arduino/host_driven/project.ino b/apps/microtvm/arduino/template_project/src/host_driven/project.ino similarity index 100% rename from apps/microtvm/arduino/host_driven/project.ino rename to apps/microtvm/arduino/template_project/src/host_driven/project.ino diff --git a/apps/microtvm/zephyr/README.md b/apps/microtvm/zephyr/README.md index ad00393c0805..68e9975d4b1c 100644 --- a/apps/microtvm/zephyr/README.md +++ b/apps/microtvm/zephyr/README.md @@ -15,5 +15,5 @@ -This directory code to interface microTVM with the [Zephyr RTOS](https://zephyrproject.org/). +This directory contains code to interface microTVM with the [Zephyr RTOS](https://zephyrproject.org/). diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index bb9c69bf4a0e..73361774821b 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -17,8 +17,9 @@ import datetime import pathlib - +import json import pytest + import tvm.target.target from tvm.micro import project from tvm import micro, relay @@ -34,19 +35,16 @@ / "template_project" ).resolve() +BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" + def arduino_boards() -> dict: """Returns a dict mapping board to target model""" - template = project.TemplateProject.from_directory(TEMPLATE_PROJECT_DIR) - project_options = template.info()["project_options"] - for option in project_options: - if option["name"] == "arduino_board": - boards = option["choices"] - if option["name"] == "arduino_model": - models = option["choices"] - - arduino_boards = {boards[i]: models[i] for i in range(len(boards))} - return arduino_boards + with open(BOARDS) as f: + board_properties = json.load(f) + + boards_model = {board: info["model"] for board, info in board_properties.items()} + return boards_model ARDUINO_BOARDS = arduino_boards() From d11bdcd3ad0717b8e38ba769e849d6a6afe6415e Mon Sep 17 00:00:00 2001 From: Joe Chou <54378300+ccjoechou@users.noreply.github.com> Date: Thu, 21 Oct 2021 14:53:33 -0700 Subject: [PATCH 66/84] [Op] Do not override specified layout in pooling (2nd PR) (#9328) * [Op] Do not override specified layout in pooling (2nd PR) * [Op] Do not override specified layout in pooling (2nd PR) * [Op] Do not override specified layout in pooling (2nd PR) * [Op] Do not override specified layout in pooling (2nd PR) --- include/tvm/relay/attrs/nn.h | 78 +++++ python/tvm/relay/op/nn/_nn.py | 110 ++++++- python/tvm/relay/op/nn/nn.py | 180 +++++++++-- src/relay/op/nn/pooling.cc | 75 +++-- src/relay/op/nn/pooling.h | 6 +- src/relay/qnn/op/convolution.cc | 4 + src/relay/transforms/pattern_utils.h | 7 +- .../test_arm_compute_lib/test_pooling.py | 2 + .../relay/test_pass_convert_op_layout.py | 288 ++++++++++++++++++ 9 files changed, 675 insertions(+), 75 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index de60deb9cccb..26d2c72c824d 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -686,6 +686,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { Array padding; Array dilation; tvm::String layout; + tvm::String out_layout; bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") { @@ -709,6 +710,13 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); } @@ -721,6 +729,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { Array padding; Array dilation; tvm::String layout; + tvm::String out_layout; bool ceil_mode; bool count_include_pad; @@ -745,6 +754,13 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); TVM_ATTR_FIELD(count_include_pad) @@ -756,6 +772,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { /*! \brief Attributes for global pool operator */ struct GlobalPool2DAttrs : public tvm::AttrsNode { tvm::String layout; + tvm::String out_layout; TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") { TVM_ATTR_FIELD(layout).set_default("NCHW").describe( @@ -763,6 +780,13 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -770,6 +794,7 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode { struct AdaptivePool1DAttrs : public tvm::AttrsNode { Array output_size; std::string layout; + tvm::String out_layout; TVM_DECLARE_ATTRS(AdaptivePool1DAttrs, "relay.attrs.AdaptivePool1DAttrs") { TVM_ATTR_FIELD(output_size).set_default(Array({})).describe("Output width."); @@ -778,6 +803,13 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode { "'N', 'C', 'W' stands for batch, channel, and width" "dimensions respectively. Pooling is applied on the" "'W' dimension."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the" + "'W' dimension."); } }; @@ -785,6 +817,7 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode { struct AdaptivePool2DAttrs : public tvm::AttrsNode { Array output_size; std::string layout; + tvm::String out_layout; TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") { TVM_ATTR_FIELD(output_size) @@ -795,6 +828,13 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -802,6 +842,7 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode { struct AdaptivePool3DAttrs : public tvm::AttrsNode { Array output_size; std::string layout; + tvm::String out_layout; TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") { TVM_ATTR_FIELD(output_size) @@ -812,6 +853,13 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode { "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Pooling is applied on 'D', 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on 'D', 'H' and" + "'W' dimensions."); } }; @@ -822,6 +870,7 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { Array dilation; Array padding; std::string layout; + tvm::String out_layout; bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") { @@ -844,6 +893,12 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." "'N', 'C', 'W' stands for batch, channel, and width" "dimensions respectively. Pooling is applied on the 'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); } @@ -856,6 +911,7 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { Array dilation; Array padding; std::string layout; + tvm::String out_layout; bool ceil_mode; bool count_include_pad; @@ -879,6 +935,12 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { "Dimension ordering of input data. Can be 'NCW', 'NHC', etc." "'N', 'C', 'W' stands for batch, channel, and width" "dimensions respectively. Pooling is applied on the 'W' dimension."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCW', 'NHC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimension."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); TVM_ATTR_FIELD(count_include_pad) @@ -894,6 +956,7 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { Array dilation; Array padding; std::string layout; + tvm::String out_layout; bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") { @@ -917,6 +980,13 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Pooling is applied on the 'D', 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); } @@ -929,6 +999,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { Array dilation; Array padding; std::string layout; + tvm::String out_layout; bool ceil_mode; bool count_include_pad; @@ -953,6 +1024,13 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Pooling is applied on the 'D', 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( "When true, will use ceil instead of floor to compute the output shape."); TVM_ATTR_FIELD(count_include_pad) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index f06ee09fc7f4..17f75a07af64 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -18,7 +18,7 @@ """Backend compiler related feature registration""" from __future__ import absolute_import -from tvm import topi +from tvm import topi, relay from tvm.topi.utils import get_const_tuple from tvm.runtime import convert @@ -267,9 +267,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): result : tvm.relay.Expr The transformed expr """ - # pylint: disable=import-outside-toplevel - from tvm import relay - data, weight = inputs # First check if there is a LayoutConfig scope, and if so, whether @@ -363,9 +360,6 @@ def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts): result : tvm.relay.Expr The transformed expr """ - # pylint: disable=import-outside-toplevel - from tvm import relay - data, weight = inputs new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" @@ -446,9 +440,6 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layouts): result : tvm.relay.Expr The transformed expr """ - # pylint: disable=import-outside-toplevel - from tvm import relay - data, weight = inputs new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs" @@ -515,6 +506,30 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_convert_op_layout("nn.max_pool2d") +def convert_max_pool2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for max_pool2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current pooling + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of one layout string + layout string defining our desired layout for input and output. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + new_attrs = dict(attrs) + new_attrs["layout"] = str(desired_layouts[0]) + new_attrs["out_layout"] = str(desired_layouts[0]) + return relay.nn.max_pool2d(*inputs, **new_attrs) + + # max_pool3d reg.register_schedule("nn.max_pool3d", strategy.schedule_pool) reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -530,6 +545,30 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_convert_op_layout("nn.avg_pool2d") +def convert_avg_pool2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for avg_pool2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current pooling + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of one layout string + layout string defining our desired layout for input and output. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + new_attrs = dict(attrs) + new_attrs["layout"] = str(desired_layouts[0]) + new_attrs["out_layout"] = str(desired_layouts[0]) + return relay.nn.avg_pool2d(*inputs, **new_attrs) + + # avg_pool3d reg.register_schedule("nn.avg_pool3d", strategy.schedule_pool) reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -560,11 +599,59 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype): reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_convert_op_layout("nn.global_max_pool2d") +def convert_global_max_pool2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for global_max_pool2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current pooling + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of one layout string + layout string defining our desired layout for input and output. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + new_attrs = dict(attrs) + new_attrs["layout"] = str(desired_layouts[0]) + new_attrs["out_layout"] = str(desired_layouts[0]) + return relay.nn.global_max_pool2d(*inputs, **new_attrs) + + # global_avg_pool2d reg.register_schedule("nn.global_avg_pool2d", strategy.schedule_adaptive_pool) reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_convert_op_layout("nn.global_avg_pool2d") +def convert_global_avg_pool2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for global_avg_pool2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current pooling + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of one layout string + layout string defining our desired layout for input and output. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + new_attrs = dict(attrs) + new_attrs["layout"] = str(desired_layouts[0]) + new_attrs["out_layout"] = str(desired_layouts[0]) + return relay.nn.global_avg_pool2d(*inputs, **new_attrs) + + # adaptive_max_pool2d reg.register_schedule("nn.adaptive_max_pool2d", strategy.schedule_adaptive_pool) reg.register_pattern("nn.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -796,9 +883,6 @@ def convert_deformable_conv2d(attrs, inputs, tinfos, desired_layouts): result : tvm.relay.Expr The transformed expr """ - # pylint: disable=import-outside-toplevel - from tvm import relay - data, offset, weight = inputs new_attrs = dict(attrs) for attr in new_attrs: diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 5a17db745b3e..1821ff17258a 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -748,7 +748,14 @@ def log_softmax(data, axis=-1): def max_pool1d( - data, pool_size=(1,), strides=(1,), dilation=(1,), padding=(0,), layout="NCW", ceil_mode=False + data, + pool_size=(1,), + strides=(1,), + dilation=(1,), + padding=(0,), + layout="NCW", + out_layout="", + ceil_mode=False, ): r"""1D maximum pooling operator. @@ -783,6 +790,9 @@ def max_pool1d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -798,7 +808,9 @@ def max_pool1d( if isinstance(dilation, int): dilation = (dilation,) padding = get_pad_tuple1d(padding) - return _make.max_pool1d(data, pool_size, strides, dilation, padding, layout, ceil_mode) + return _make.max_pool1d( + data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode + ) def max_pool2d( @@ -808,6 +820,7 @@ def max_pool2d( dilation=(1, 1), padding=(0, 0), layout="NCHW", + out_layout="", ceil_mode=False, ): r"""2D maximum pooling operator. @@ -851,6 +864,9 @@ def max_pool2d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -866,7 +882,9 @@ def max_pool2d( if isinstance(dilation, int): dilation = (dilation, dilation) padding = get_pad_tuple2d(padding) - return _make.max_pool2d(data, pool_size, strides, dilation, padding, layout, ceil_mode) + return _make.max_pool2d( + data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode + ) def max_pool3d( @@ -876,6 +894,7 @@ def max_pool3d( dilation=(1, 1, 1), padding=(0, 0, 0), layout="NCDHW", + out_layout="", ceil_mode=False, ): r"""3D maximum pooling operator. @@ -912,6 +931,9 @@ def max_pool3d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -927,7 +949,9 @@ def max_pool3d( if isinstance(dilation, int): dilation = (dilation, dilation, dilation) padding = get_pad_tuple3d(padding) - return _make.max_pool3d(data, pool_size, strides, dilation, padding, layout, ceil_mode) + return _make.max_pool3d( + data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode + ) def avg_pool1d( @@ -937,6 +961,7 @@ def avg_pool1d( dilation=(1,), padding=(0,), layout="NCW", + out_layout="", ceil_mode=False, count_include_pad=False, ): @@ -973,6 +998,9 @@ def avg_pool1d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -992,7 +1020,15 @@ def avg_pool1d( dilation = (dilation,) padding = get_pad_tuple1d(padding) return _make.avg_pool1d( - data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + data, + pool_size, + strides, + dilation, + padding, + layout, + out_layout, + ceil_mode, + count_include_pad, ) @@ -1003,6 +1039,7 @@ def avg_pool2d( dilation=(1, 1), padding=(0, 0), layout="NCHW", + out_layout="", ceil_mode=False, count_include_pad=False, ): @@ -1048,6 +1085,9 @@ def avg_pool2d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -1067,7 +1107,15 @@ def avg_pool2d( dilation = (dilation, dilation) padding = get_pad_tuple2d(padding) return _make.avg_pool2d( - data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + data, + pool_size, + strides, + dilation, + padding, + layout, + out_layout, + ceil_mode, + count_include_pad, ) @@ -1078,6 +1126,7 @@ def avg_pool3d( dilation=(1, 1, 1), padding=(0, 0, 0), layout="NCDHW", + out_layout="", ceil_mode=False, count_include_pad=False, ): @@ -1115,6 +1164,9 @@ def avg_pool3d( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -1134,7 +1186,15 @@ def avg_pool3d( dilation = (dilation, dilation, dilation) padding = get_pad_tuple3d(padding) return _make.avg_pool3d( - data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + data, + pool_size, + strides, + dilation, + padding, + layout, + out_layout, + ceil_mode, + count_include_pad, ) @@ -1145,6 +1205,7 @@ def max_pool2d_grad( strides=(1, 1), padding=(0, 0), layout="NCHW", + out_layout="", ceil_mode=False, ): r"""Gradient of 2D maximum pooling operator. @@ -1171,6 +1232,9 @@ def max_pool2d_grad( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -1179,7 +1243,9 @@ def max_pool2d_grad( result : tvm.relay.Expr The computed result. """ - return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding, layout, ceil_mode) + return _make.max_pool2d_grad( + out_grad, data, pool_size, strides, padding, layout, out_layout, ceil_mode + ) def avg_pool2d_grad( @@ -1189,6 +1255,7 @@ def avg_pool2d_grad( strides=(1, 1), padding=(0, 0), layout="NCHW", + out_layout="", ceil_mode=False, count_include_pad=False, ): @@ -1216,6 +1283,9 @@ def avg_pool2d_grad( layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + ceil_mode : bool, optional To enable or disable ceil while pooling. @@ -1228,11 +1298,19 @@ def avg_pool2d_grad( The computed result. """ return _make.avg_pool2d_grad( - out_grad, data, pool_size, strides, padding, layout, ceil_mode, count_include_pad + out_grad, + data, + pool_size, + strides, + padding, + layout, + out_layout, + ceil_mode, + count_include_pad, ) -def global_max_pool2d(data, layout="NCHW"): +def global_max_pool2d(data, layout="NCHW", out_layout=""): r"""2D global maximum pooling operator. This operator takes data as input and does 2D max value calculation @@ -1258,15 +1336,18 @@ def global_max_pool2d(data, layout="NCHW"): layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.global_max_pool2d(data, layout) + return _make.global_max_pool2d(data, layout, out_layout) -def global_avg_pool2d(data, layout="NCHW"): +def global_avg_pool2d(data, layout="NCHW", out_layout=""): r"""2D global average pooling operator. This operator takes data as input and does 2D average value calculation @@ -1292,12 +1373,15 @@ def global_avg_pool2d(data, layout="NCHW"): layout : str, optional Layout of the input. + out_layout : Optional[str] + Layout of the output + Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.global_avg_pool2d(data, layout) + return _make.global_avg_pool2d(data, layout, out_layout) def upsampling( @@ -3114,7 +3198,7 @@ def space_to_depth(data, block_size, layout="NCHW"): return _make.space_to_depth(data, block_size, layout) -def adaptive_max_pool1d(data, output_size=None, layout="NCW"): +def adaptive_max_pool1d(data, output_size=None, layout="NCW", out_layout=""): r"""1D adaptive max pooling operator. This operator is experimental. This operator takes data as input and does 1D max value calculation @@ -3147,6 +3231,9 @@ def adaptive_max_pool1d(data, output_size=None, layout="NCW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr @@ -3155,10 +3242,10 @@ def adaptive_max_pool1d(data, output_size=None, layout="NCW"): output_size = [] or output_size if isinstance(output_size, int): output_size = [output_size] - return _make.adaptive_max_pool1d(data, output_size, layout) + return _make.adaptive_max_pool1d(data, output_size, layout, out_layout) -def adaptive_avg_pool1d(data, output_size=None, layout="NCW"): +def adaptive_avg_pool1d(data, output_size=None, layout="NCW", out_layout=""): r"""1D adaptive average pooling operator. This operator is experimental. This operator takes data as input and does 1D average value calculation @@ -3191,6 +3278,9 @@ def adaptive_avg_pool1d(data, output_size=None, layout="NCW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr @@ -3199,10 +3289,10 @@ def adaptive_avg_pool1d(data, output_size=None, layout="NCW"): output_size = [] or output_size if isinstance(output_size, int): output_size = [output_size] - return _make.adaptive_avg_pool1d(data, output_size, layout) + return _make.adaptive_avg_pool1d(data, output_size, layout, out_layout) -def adaptive_max_pool2d(data, output_size=None, layout="NCHW"): +def adaptive_max_pool2d(data, output_size=None, layout="NCHW", out_layout=""): r"""2D adaptive max pooling operator. This operator is experimental. This operator takes data as input and does 2D max value calculation @@ -3238,16 +3328,19 @@ def adaptive_max_pool2d(data, output_size=None, layout="NCHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [] or output_size - return _make.adaptive_max_pool2d(data, output_size, layout) + return _make.adaptive_max_pool2d(data, output_size, layout, out_layout) -def adaptive_avg_pool2d(data, output_size=None, layout="NCHW"): +def adaptive_avg_pool2d(data, output_size=None, layout="NCHW", out_layout=""): r"""2D adaptive average pooling operator. This operator is experimental. This operator takes data as input and does 2D average value calculation @@ -3283,16 +3376,19 @@ def adaptive_avg_pool2d(data, output_size=None, layout="NCHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [] or output_size - return _make.adaptive_avg_pool2d(data, output_size, layout) + return _make.adaptive_avg_pool2d(data, output_size, layout, out_layout) -def adaptive_max_pool3d(data, output_size=None, layout="NCDHW"): +def adaptive_max_pool3d(data, output_size=None, layout="NCDHW", out_layout=""): r"""3D adaptive max pooling operator. This operator is experimental. This operator takes data as input and does 3D max value calculation @@ -3327,16 +3423,19 @@ def adaptive_max_pool3d(data, output_size=None, layout="NCDHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [] or output_size - return _make.adaptive_max_pool3d(data, output_size, layout) + return _make.adaptive_max_pool3d(data, output_size, layout, out_layout) -def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"): +def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW", out_layout=""): r"""3D adaptive avg pooling operator. This operator is experimental. This operator takes data as input and does 3D avg value calculation @@ -3371,16 +3470,19 @@ def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [] or output_size - return _make.adaptive_avg_pool3d(data, output_size, layout) + return _make.adaptive_avg_pool3d(data, output_size, layout, out_layout) -def global_max_pool1d(data, layout="NCW"): +def global_max_pool1d(data, layout="NCW", out_layout=""): r"""1D global maximum pooling operator. This operator takes data as input and does 1D max value calculation @@ -3403,16 +3505,19 @@ def global_max_pool1d(data, layout="NCW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [1] - return _make.adaptive_max_pool1d(data, output_size, layout) + return _make.adaptive_max_pool1d(data, output_size, layout, out_layout) -def global_avg_pool1d(data, layout="NCW"): +def global_avg_pool1d(data, layout="NCW", out_layout=""): r"""1D global average pooling operator. This operator takes data as input and does 1D average value calculation @@ -3436,16 +3541,19 @@ def global_avg_pool1d(data, layout="NCW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [1] - return _make.adaptive_avg_pool1d(data, output_size, layout) + return _make.adaptive_avg_pool1d(data, output_size, layout, out_layout) -def global_max_pool3d(data, layout="NCDHW"): +def global_max_pool3d(data, layout="NCDHW", out_layout=""): r"""3D global maximum pooling operator. This operator takes data as input and does 3D max value calculation @@ -3469,16 +3577,19 @@ def global_max_pool3d(data, layout="NCDHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [1, 1, 1] - return _make.adaptive_max_pool3d(data, output_size, layout) + return _make.adaptive_max_pool3d(data, output_size, layout, out_layout) -def global_avg_pool3d(data, layout="NCDHW"): +def global_avg_pool3d(data, layout="NCDHW", out_layout=""): r"""3D global average pooling operator. This operator takes data as input and does 3D average value calculation @@ -3503,13 +3614,16 @@ def global_avg_pool3d(data, layout="NCDHW"): layout : str, optional Layout of the input. + out_layout : str, optional + Layout of the output. + Returns ------- result : tvm.relay.Expr The computed result. """ output_size = [1, 1, 1] - return _make.adaptive_avg_pool3d(data, output_size, layout) + return _make.adaptive_avg_pool3d(data, output_size, layout, out_layout) def correlation( diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 0d40caa15052..cf44b308ce02 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -49,8 +49,13 @@ InferCorrectLayoutOutput PoolInferCorrectLayout(const Attrs& attrs, ICHECK(attrs_ptr); ObjectPtr params = make_object(*attrs_ptr); - if (new_in_layouts.defined()) { - // Set the pool with the new layout. + if (params->out_layout != "") { + // when users specify the out_layout of pooling, follow user's preference + ICHECK_EQ(params->layout, params->out_layout) + << "Pooling input/output layouts mismatch: " << params->layout << " vs. " + << params->out_layout; + } else if (new_in_layouts.defined()) { + // the pooling is using an inferred layout (i.e., new_in_layouts[0]) given by relay caller ICHECK_EQ(new_in_layouts.size(), 1); params->layout = new_in_layouts[0].name(); } @@ -144,6 +149,7 @@ Array Pool2DCompute(const Attrs& attrs, const Array& inp auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(tir::BijectiveLayout(layout, kNCHW).defined()) << "max_pool2d currently only supports layouts that are convertible from NCHW"; @@ -178,9 +184,9 @@ Array Pool2DCompute(const Attrs& attrs, const Array& inp TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode) { + String out_layout, bool ceil_mode) { return MakeMaxPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, "nn.max_pool2d"); + out_layout, ceil_mode, "nn.max_pool2d"); }); RELAY_REGISTER_OP("nn.max_pool2d") @@ -216,9 +222,9 @@ RELAY_REGISTER_OP("nn.max_pool2d") TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, bool count_include_pad) { + String out_layout, bool ceil_mode, bool count_include_pad) { return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, count_include_pad, "nn.avg_pool2d"); + out_layout, ceil_mode, count_include_pad, "nn.avg_pool2d"); }); RELAY_REGISTER_OP("nn.avg_pool2d") @@ -303,9 +309,10 @@ Array GlobalPool2DCompute(const Attrs& attrs, const Array{topi::nn::global_pool(inputs[0], mode, layout.name())}; } -Expr MakeGlobalAvgPool2D(Expr data, String layout) { +Expr MakeGlobalAvgPool2D(Expr data, String layout, String out_layout) { auto attrs = make_object(); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.global_avg_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -331,9 +338,10 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") .set_attr("FTVMCompute", GlobalPool2DCompute); // GlobalMaxPool -Expr MakeGlobalMaxPool2D(Expr data, String layout) { +Expr MakeGlobalMaxPool2D(Expr data, String layout, String out_layout) { auto attrs = make_object(); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.global_max_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -423,10 +431,12 @@ Array AdaptivePool1DCompute(const Attrs& attrs, const Array output_size, String layout) { +Expr MakeAdaptiveAvgPool1D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_avg_pool1d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -456,10 +466,12 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool1d") .set_attr("FTVMCompute", AdaptivePool1DCompute); // relay.nn.adaptive_max_pool1d -Expr MakeAdaptiveMaxPool1D(Expr data, Array output_size, String layout) { +Expr MakeAdaptiveMaxPool1D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_max_pool1d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -571,10 +583,12 @@ Array AdaptivePool2DCompute(const Attrs& attrs, const Array output_size, String layout) { +Expr MakeAdaptiveAvgPool2D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_avg_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -606,10 +620,12 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") .set_attr("FTVMCompute", AdaptivePool2DCompute); // relay.nn.adaptive_max_pool2d -Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, String layout) { +Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_max_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -700,6 +716,7 @@ Array AdaptivePool3DCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; ICHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) @@ -737,10 +754,12 @@ Array AdaptivePool3DCompute(const Attrs& attrs, const Array output_size, String layout) { +Expr MakeAdaptiveMaxPool3D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_max_pool3d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -772,10 +791,12 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool3d") .set_attr("FTVMCompute", AdaptivePool3DCompute); // relay.nn.adaptive_max_pool3d -Expr MakeAdaptiveAvgPool3D(Expr data, Array output_size, String layout) { +Expr MakeAdaptiveAvgPool3D(Expr data, Array output_size, String layout, + String out_layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); static const Op& op = Op::Get("nn.adaptive_avg_pool3d"); return Call(op, {data}, Attrs(attrs), {}); } @@ -866,12 +887,13 @@ Array Pool2DGradCompute(const Attrs& attrs, const Array& // MaxPool2DGrad Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, Array strides, Array padding, String layout, - bool ceil_mode) { + String out_layout, bool ceil_mode) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); attrs->ceil_mode = ceil_mode; static const Op& op = Op::Get("nn.max_pool2d_grad"); return Call(op, {out_grad, data}, Attrs(attrs), {}); @@ -913,12 +935,13 @@ RELAY_REGISTER_OP("nn.max_pool2d_grad") // AvgPool2DGrad Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, Array strides, Array padding, String layout, - bool ceil_mode, bool count_include_pad) { + String out_layout, bool ceil_mode, bool count_include_pad) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; static const Op& op = Op::Get("nn.avg_pool2d_grad"); @@ -976,6 +999,7 @@ bool Pool1DRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK(param != nullptr); Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w'))) << "Invalid layout " << layout << ". Pool1D layout must have W, which cannot be split"; @@ -1018,6 +1042,7 @@ Array Pool1DCompute(const Attrs& attrs, const Array& inp auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(tir::BijectiveLayout(layout, kNCW).defined()) << "max_pool1d currently only supports layouts that are convertible from NCW"; @@ -1046,9 +1071,9 @@ Array Pool1DCompute(const Attrs& attrs, const Array& inp TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode) { + String out_layout, bool ceil_mode) { return MakeMaxPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, "nn.max_pool1d"); + out_layout, ceil_mode, "nn.max_pool1d"); }); RELAY_REGISTER_OP("nn.max_pool1d") @@ -1082,9 +1107,9 @@ RELAY_REGISTER_OP("nn.max_pool1d") TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, bool count_include_pad) { + String out_layout, bool ceil_mode, bool count_include_pad) { return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, count_include_pad, "nn.avg_pool1d"); + out_layout, ceil_mode, count_include_pad, "nn.avg_pool1d"); }); RELAY_REGISTER_OP("nn.avg_pool1d") @@ -1134,6 +1159,7 @@ bool Pool3DRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK(param != nullptr); Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) @@ -1194,6 +1220,7 @@ Array Pool3DCompute(const Attrs& attrs, const Array& inp auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); + Layout out_layout(param->out_layout); ICHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) << "max_pool3d currently only supports layouts that are convertible from NCDHW"; @@ -1231,9 +1258,9 @@ Array Pool3DCompute(const Attrs& attrs, const Array& inp TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode) { + String out_layout, bool ceil_mode) { return MakeMaxPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, "nn.max_pool3d"); + out_layout, ceil_mode, "nn.max_pool3d"); }); RELAY_REGISTER_OP("nn.max_pool3d") @@ -1270,9 +1297,9 @@ RELAY_REGISTER_OP("nn.max_pool3d") TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") .set_body_typed([](Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, bool count_include_pad) { + String out_layout, bool ceil_mode, bool count_include_pad) { return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, - ceil_mode, count_include_pad, "nn.avg_pool3d"); + out_layout, ceil_mode, count_include_pad, "nn.avg_pool3d"); }); RELAY_REGISTER_OP("nn.avg_pool3d") diff --git a/src/relay/op/nn/pooling.h b/src/relay/op/nn/pooling.h index 9b7eab25fe9a..32ae464101ab 100644 --- a/src/relay/op/nn/pooling.h +++ b/src/relay/op/nn/pooling.h @@ -35,13 +35,14 @@ namespace relay { template inline Expr MakeMaxPool(Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, String op_name) { + String out_layout, bool ceil_mode, String op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->dilation = std::move(dilation); attrs->padding = std::move(padding); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); attrs->ceil_mode = ceil_mode; static const Op& op = Op::Get(op_name); return Call(op, {data}, Attrs(attrs), {}); @@ -50,13 +51,14 @@ inline Expr MakeMaxPool(Expr data, Array pool_size, Array template inline Expr MakeAvgPool(Expr data, Array pool_size, Array strides, Array dilation, Array padding, String layout, - bool ceil_mode, bool count_include_pad, String op_name) { + String out_layout, bool ceil_mode, bool count_include_pad, String op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->dilation = std::move(dilation); attrs->padding = std::move(padding); attrs->layout = std::move(layout); + attrs->out_layout = std::move(out_layout); attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; static const Op& op = Op::Get(op_name); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 5782f1f6b4d1..ecdd36ddb791 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -275,6 +275,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ Array padding({0, 0}); reduced_t2 = AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, param->dilation, padding, param->data_layout, + "", // out_layout false, // ceil_mode false); // count_include_pad } else { @@ -284,6 +285,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ Array padding({0, 0}); reduced_t2 = AvgPool2D(reduced_t2, param->kernel_size, param->strides, param->dilation, padding, param->data_layout, + "", // out_layout false, // ceil_mode false); // count_include_pad } @@ -463,6 +465,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, Multiply(reduced_c_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w)); reduced_t2 = AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, param->dilation, padding, param->data_layout, + "", // out_layout false, // ceil_mode false); // count_include_pad } else { @@ -471,6 +474,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, if (stride1 * stride2 != 1) { reduced_t2 = AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, param->dilation, padding, param->data_layout, + "", // out_layout false, // ceil_mode false); // count_include_pad } diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 692ef3c9f557..03b8ee6937a7 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -676,9 +676,10 @@ static inline Expr Reshape(Expr data, Array newshape) { static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides, Array dilation, Array padding, - std::string layout, bool ceil_mode, bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool2d"); + std::string layout, std::string out_layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, + out_layout, ceil_mode, count_include_pad, "nn.avg_pool2d"); } static inline Expr Pad(Expr data, Array> pad_width, Expr pad_value, diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py index 9deaa758639e..b174f9a78866 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_pooling.py +++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py @@ -123,6 +123,7 @@ def _get_expected_pooling_codegen( "num_inputs": "1", "num_outputs": "1", "layout": [["NHWC"]], + "out_layout": [[""]], "shape": [[list(output_shape)]], "dtype": [[dtype]], "padding": [[str(p) for p in padding]], @@ -149,6 +150,7 @@ def _get_expected_global_pooling_codegen(shape, dtype, typef): "num_inputs": "1", "num_outputs": "1", "layout": [["NHWC"]], + "out_layout": [[""]], "shape": [[[1, 1, 1, shape[3]]]], "dtype": [[dtype]], }, diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 9b4d154360b2..2359dcdf93d9 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -248,6 +248,61 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_bias_pool_uses_specified_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.nn.bias_add(y, bias, axis=3) + # a useless tuple, which will be eliminated + y = relay.Tuple([y])[0] + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC") + y = relay.cast(y, "int32") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, "NHWC", "NCHW") + weight = relay.layout_transform(weight, "HWIO", "OIHW") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + + bias = relay.expand_dims(bias, axis=0, num_newaxis=3) + bias = relay.layout_transform(bias, "NHWC", "NCHW") + y = relay.add(y, bias) + # a useless tuple, which will be eliminated + y = relay.Tuple([y])[0] + y = relay.nn.relu(y) + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC", out_layout="NHWC") + y = relay.cast(y, "int32") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass( + a, + transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"], "nn.max_pool2d": ["NHWC"]}), + ) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + def test_conv_concat_convert_layout(): def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -412,6 +467,139 @@ def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_deformable_conv_bias_pool_uses_specified_convert_layout(): + def before(N, CI, H, W, CO, KH, KW, layout): + if layout == "NCHW": + data_shape = (N, CI, H, W) + weight_shape = (CO, CI, KH, KW) + kernel_layout = "OIHW" + else: + data_shape = (N, H, W, CI) + weight_shape = (KH, KW, CI, CO) + kernel_layout = "HWIO" + bias_shape = (CO,) + + data = relay.var("data", shape=data_shape, dtype="float32") + offset = relay.var("offset") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + bias = relay.var("bias", shape=bias_shape, dtype="float32") + + y = relay.nn.deformable_conv2d( + data, + offset, + weight, + kernel_size=(KH, KW), + channels=CO, + data_layout=layout, + kernel_layout=kernel_layout, + ) + y = relay.nn.bias_add(y, bias, axis=-1 if layout == "NHWC" else 1) + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout=layout) + y = relay.cast(y, "int32") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout, max_pool_layout=None): + layout_map = {"src": {}, "dst": {}} + if src_layout == "NCHW": + nchw = layout_map["src"] + nhwc = layout_map["dst"] + else: + nchw = layout_map["dst"] + nhwc = layout_map["src"] + + nchw["data_layout"] = "NCHW" + nchw["data_shape"] = (N, CI, H, W) + nchw["offset_shape"] = (N, KH * KW * 2, OH, OW) + nchw["weight_shape"] = (CO, CI, KH, KW) + nchw["kernel_layout"] = "OIHW" + + nhwc["data_layout"] = "NHWC" + nhwc["data_shape"] = (N, H, W, CI) + nhwc["offset_shape"] = (N, OH, OW, KH * KW * 2) + nhwc["weight_shape"] = (KH, KW, CI, CO) + nhwc["kernel_layout"] = "HWIO" + + bias_shape = (CO,) + + data = relay.var("data", shape=layout_map["src"]["data_shape"], dtype="float32") + offset = relay.var("offset", shape=layout_map["src"]["offset_shape"], dtype="float32") + weight = relay.var("weight", shape=layout_map["src"]["weight_shape"], dtype="float32") + bias = relay.var("bias", shape=bias_shape, dtype="float32") + + data = relay.layout_transform( + data, layout_map["src"]["data_layout"], layout_map["dst"]["data_layout"] + ) + offset = relay.layout_transform( + offset, layout_map["src"]["data_layout"], layout_map["dst"]["data_layout"] + ) + weight = relay.layout_transform( + weight, layout_map["src"]["kernel_layout"], layout_map["dst"]["kernel_layout"] + ) + y = relay.nn.deformable_conv2d( + data, + offset, + weight, + kernel_size=(KH, KW), + channels=CO, + data_layout=layout_map["dst"]["data_layout"], + kernel_layout=layout_map["dst"]["kernel_layout"], + ) + if layout_map["src"]["data_layout"] == "NHWC": + bias = relay.expand_dims(bias, axis=0, num_newaxis=3) + else: + bias = relay.expand_dims(bias, axis=1, num_newaxis=2) + bias = relay.expand_dims(bias, axis=0) + bias = relay.layout_transform( + bias, layout_map["src"]["data_layout"], layout_map["dst"]["data_layout"] + ) + y = relay.add(y, bias) + y = relay.nn.relu(y) + if max_pool_layout != layout_map["dst"]["data_layout"]: + y = relay.layout_transform(y, layout_map["dst"]["data_layout"], max_pool_layout) + y = relay.nn.max_pool2d( + y, pool_size=(2, 2), layout=max_pool_layout, out_layout=max_pool_layout + ) + y = relay.cast(y, "int32") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + # NHWC -> NCHW + a = before(1, 3, 224, 224, 32, 3, 3, "NHWC") + a = run_opt_pass( + a, + transform.ConvertLayout( + {"nn.deformable_conv2d": ["NCHW", "default"], "nn.max_pool2d": ["NHWC"]} + ), + ) + # - in the before() func, its last argument "NHWC" is also the layout of max_pool + b = run_opt_pass( + # max_pool has its own layout argument + expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NHWC", "NCHW", max_pool_layout="NHWC"), + transform.InferType(), + ) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + # NCHW -> NHWC + a = before(1, 3, 224, 224, 32, 3, 3, "NCHW") + a = run_opt_pass( + a, + transform.ConvertLayout( + {"nn.deformable_conv2d": ["NHWC", "default"], "nn.max_pool2d": ["NCHW"]} + ), + ) + # - in the before() func, its last argument "NCHW" is also the layout of max_pool + b = run_opt_pass( + # max_pool has its own layout argument + expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NCHW", "NHWC", max_pool_layout="NCHW"), + transform.InferType(), + ) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + def test_dual_path_convert_layout(): def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -702,6 +890,57 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_resnet_pool_uses_specified_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + weight2 = relay.var("weight2", shape=(1, 1, 64, 32)) + y = relay.nn.conv2d( + x, + weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.nn.relu(y) + y2 = relay.nn.conv2d( + x, weight2, channels=32, kernel_size=(1, 1), data_layout="NHWC", kernel_layout="HWIO" + ) + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.nn.global_max_pool2d(y, layout="NHWC") + return relay.Function(analysis.free_vars(y), y) + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + weight2 = relay.var("weight2", shape=(1, 1, 64, 32)) + weight1 = relay.layout_transform(weight1, "HWIO", "OIHW") + weight2 = relay.layout_transform(weight2, "HWIO", "OIHW") + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.relu(y) + y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1)) + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.nn.global_max_pool2d(y, layout="NHWC", out_layout="NHWC") + return relay.Function(analysis.free_vars(y), y) + + a = before() + a = run_opt_pass( + a, + transform.ConvertLayout( + {"nn.conv2d": ["NCHW", "default"], "nn.global_max_pool2d": ["NHWC"]} + ), + ) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + def test_scalar_convert_layout(): def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -2039,5 +2278,54 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_max_pool_uses_specified_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + x = relay.layout_transform(x, "NCHW", "NHWC") + weight = relay.layout_transform(weight, "OIHW", "OHWI") + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + ) + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC", out_layout="NHWC") + y = relay.layout_transform(y, "NHWC", "NCHW") + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass( + a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"], "nn.max_pool2d": ["NHWC"]}) + ) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + + if __name__ == "__main__": pytest.main([__file__]) From edda830af0bb50edf03477882f9dee4888992dad Mon Sep 17 00:00:00 2001 From: Leo-arm <52416576+Leo-arm@users.noreply.github.com> Date: Fri, 22 Oct 2021 11:15:50 +0100 Subject: [PATCH 67/84] [ETHOSN] Match config for is-supported with compilation target (#9160) The Ethos-N variant configuration for the is-supported functionality is now the same as the variant configuration for the actual compilation --- src/relay/backend/contrib/ethosn/codegen.cc | 73 +++++++++++++------ .../backend/contrib/ethosn/codegen_ethosn.h | 18 +++++ 2 files changed, 69 insertions(+), 22 deletions(-) diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index 97b308e51e18..3e675215e7e0 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -606,25 +606,37 @@ std::pair, std::vector> EthosnCompiler::GetInput return std::make_pair(input_order, output_order); } -auto ctx = transform::PassContext::Current(); -auto cfg = ctx -> GetConfig("relay.ext.ethos-n.options").defined() - ? ctx -> GetConfig("relay.ext.ethos-n.options") - : AttrsWithDefaultValues(); -auto m_Queries = sl::SupportQueries(sl::GetFwAndHwCapabilities( - sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); +std::unique_ptr EthosnCompiler::m_Queries; + +EthosnError EthosnCompiler::SupportedSetup() { + if (m_Queries == nullptr) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relay.ext.ethos-n.options").defined() + ? ctx->GetConfig("relay.ext.ethos-n.options") + : AttrsWithDefaultValues(); + m_Queries = std::make_unique(sl::GetFwAndHwCapabilities( + sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); + if (m_Queries == nullptr) { + return EthosnError("Could not initialise Ethos-N compiler isSupported"); + } + } + return EthosnError(); +} TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ConvolutionParams params; auto err = EthosnAPI::QnnConv2d(call, ¶ms); + err += EthosnCompiler::SupportedSetup(); if (params.is_depthwise) { *rv = !err && - m_Queries.IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); + EthosnCompiler::GetSupported()->IsDepthwiseConvolutionSupported( + params.bias_info, params.weights_info, params.conv_info, params.activation_info); } else { - *rv = !err && m_Queries.IsConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); + *rv = !err && + EthosnCompiler::GetSupported()->IsConvolutionSupported( + params.bias_info, params.weights_info, params.conv_info, params.activation_info); } }); @@ -633,8 +645,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") Call call = args[0]; FullyConnectedParams params; auto err = EthosnAPI::QnnFullyConnected(call, ¶ms); - *rv = !err && m_Queries.IsFullyConnectedSupported(params.bias_info, params.weights_info, - params.fc_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsFullyConnectedSupported( + params.bias_info, params.weights_info, params.fc_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") @@ -642,7 +655,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") Call call = args[0]; MaxPool2DParams params; auto err = EthosnAPI::MaxPool2D(call, ¶ms); - *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") @@ -650,7 +665,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") Call call = args[0]; AvgPool2DParams params; auto err = EthosnAPI::AvgPool2D(call, ¶ms); - *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") @@ -658,7 +675,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") Call call = args[0]; ReshapeParams params; auto err = EthosnAPI::Reshape(call, ¶ms); - *rv = !err && m_Queries.IsReshapeSupported(params.new_shape, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsReshapeSupported(params.new_shape, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") @@ -666,8 +685,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") Call call = args[0]; AdditionParams params; auto err = EthosnAPI::Addition(call, ¶ms); - *rv = !err && m_Queries.IsAdditionSupported(params.lhs_info, params.rhs_info, - params.output_quantization_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsAdditionSupported( + params.lhs_info, params.rhs_info, params.output_quantization_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") @@ -675,7 +695,8 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") Call call = args[0]; SigmoidParams params; auto err = EthosnAPI::Sigmoid(call, ¶ms); - *rv = !err && m_Queries.IsSigmoidSupported(params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsSigmoidSupported(params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") @@ -683,7 +704,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") Call call = args[0]; ConcatenateParams params; auto err = EthosnAPI::Concatenate(call, ¶ms); - *rv = !err && m_Queries.IsConcatenationSupported(params.input_infos, params.concat_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsConcatenationSupported(params.input_infos, + params.concat_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") @@ -691,7 +714,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") Call call = args[0]; SplitParams params; auto err = EthosnAPI::Split(call, ¶ms); - *rv = !err && m_Queries.IsSplitSupported(params.input_info, params.split_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsSplitSupported(params.input_info, params.split_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") @@ -699,7 +724,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") Call call = args[0]; DepthToSpaceParams params; auto err = EthosnAPI::DepthToSpace(call, ¶ms); - *rv = !err && m_Queries.IsDepthToSpaceSupported(params.input_info, params.depth_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsDepthToSpaceSupported(params.input_info, + params.depth_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") @@ -707,7 +734,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") Call call = args[0]; ReluParams params; auto err = EthosnAPI::Relu(call, ¶ms); - *rv = !err && m_Queries.IsReluSupported(params.relu_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsReluSupported(params.relu_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index 63ae7a3e4704..ca2df05e958d 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -287,6 +287,22 @@ class EthosnCompiler { */ static runtime::Module CreateRuntimeModule(const ObjectRef& ref); + /*! + * \brief Initialise the is-supported functionality of the Ethos-N support library + * with the target variant. + * \return Error object + */ + static EthosnError SupportedSetup(); + + /*! + * \brief Return the is-supported API of the Support Library + * \return A reference to the API. + */ + static std::unique_ptr& GetSupported() { + ICHECK(m_Queries != nullptr); + return m_Queries; + } + private: /*! * \brief Compile a single Relay Ethos-N function into an ordered compiled network. @@ -322,6 +338,8 @@ class EthosnCompiler { */ static std::pair, std::vector> GetInputOutputOrder( NetworkWithIDs network, const std::unique_ptr& compiled_network); + + static std::unique_ptr m_Queries; }; runtime::Module CompileEthosn(const ObjectRef& ref) { From cec6ebb9edfb996f8f2f47cc8b9a24c18950dc85 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 22 Oct 2021 06:45:19 -0700 Subject: [PATCH 68/84] [Community] @ganler -> Reviewer (#9346) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 19287b4cbfd5..6c63793fa217 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -118,6 +118,7 @@ We do encourage everyone to work anything they are interested in. - [Trevor Morris](https://github.com/trevor-m): @trevor-m - [Tatsuya Nishiyama](https://github.com/nishi-t): @nishi-t - [Leandro Nunes](https://github.com/leandron): @leandron +- [Jiawei Liu](https://github.com/ganler): @ganler - [Lily Orth-Smith](https://github.com/electriclilies): @electriclilies - [Wei Pan](https://github.com/wpan11nv): @wpan11nv - [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic From 5ca646b5a8da811bd110aaecf7dbeb827b36f345 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Fri, 22 Oct 2021 06:45:28 -0700 Subject: [PATCH 69/84] BUG: Look through on_device annotations when looking for shape constants (#9345) https://github.com/apache/tvm/pull/8788 introduced a perf regression since a `shape.as` in `alloc_tensor` was always failing due to the extra `on_device` annotation on the constant. Fixed that, and introduced some helpers to make this situation easier to deal with. (This is CORE-102 in OctoML JIRA). (Second try -- test_crp.py failure seems unrelated) --- src/relay/backend/aot_executor_codegen.cc | 3 +-- src/relay/backend/graph_plan_memory.cc | 5 ++--- src/relay/backend/vm/compiler.cc | 7 +++--- src/relay/op/annotation/annotation.h | 26 +++++++++++++++++++++++ src/relay/op/memory/memory.cc | 10 +++------ src/relay/transforms/pass_utils.h | 5 ++--- tests/python/relay/test_vm.py | 15 ++++++++++++- 7 files changed, 52 insertions(+), 19 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 56e008a345de..3c9c35c4f254 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -182,9 +182,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { * \return The corresponding token. */ StorageInfo GetStorage(const Expr& expr) { - auto props = GetOnDeviceProps(expr); // See through "on_device" calls. - Expr true_expr = props.body.defined() ? props.body : expr; + Expr true_expr = IgnoreOnDevice(expr); VisitExpr(true_expr); auto it = storage_device_map_.find(true_expr); ICHECK(it != storage_device_map_.end()); diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 7642f3ccf703..961252a14fa7 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -146,10 +146,9 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * \return The corresponding token. */ const std::vector& GetToken(const Expr& expr) { - this->VisitExpr(expr); // See through on_device calls. - auto props = GetOnDeviceProps(expr); - Expr real_expr = props.body.defined() ? props.body : expr; + Expr real_expr = IgnoreOnDevice(expr); + this->VisitExpr(real_expr); auto it = token_map_.find(real_expr.get()); ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl << PrettyPrint(real_expr); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 70ad2ccc992e..b3c1cd81274f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -594,8 +594,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { auto offset_register = last_register_; // If the shape is constant then we will emit a static tensor allocation - // instruction. - auto const_shape = args[2].as(); + // instruction. It may be wrapped by an on_device, but it will be on the host + // which is assumed by the alloc_tensor instruction anyway. + auto const_shape = AsIgnoringOnDevice(args[2]); if (const_shape) { NDArray shape = const_shape->data; @@ -619,7 +620,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { this->VisitExpr(args[0]); auto size_register = last_register_; - ICHECK(args[1].as()); + ICHECK(args[1].as()); // Always a literal. NDArray alignment_arr = args[1].as()->data; ICHECK_EQ(alignment_arr->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index b6dff8813fd4..d772df9b023a 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -85,6 +85,32 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node); */ OnDeviceProps GetOnDeviceProps(const Expr& expr); +/*! + * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns + * \p expr directly. + */ +inline Expr IgnoreOnDevice(const Expr& expr) { + OnDeviceProps props = GetOnDeviceProps(expr); + return props.body.defined() ? props.body : expr; +} + +/*! + * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through + * any "on_device" annotations. + */ +template +const NodeType* AsIgnoringOnDevice(const Expr& expr) { + const auto* node = expr.as(); + if (node != nullptr) { + return node; + } + OnDeviceProps props = GetOnDeviceProps(expr); + if (!props.body.defined()) { + return nullptr; + } + return props.body.as(); +} + /*! * \brief Returns \p function annotated with "param_device_types" and "result_device_type" * attributes capturing parameter and result devices types respectively. diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 6b22cfd6bdba..08e92b31965e 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -101,13 +101,9 @@ Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, attrs->assert_shape = assert_shape; } else { // Look through any on_device for the shape argument expression. - Expr literal_shape = shape; - auto props = GetOnDeviceProps(literal_shape); - if (props.body.defined()) { - // See through on_device calls. - literal_shape = props.body; - } - attrs->const_shape = Downcast(literal_shape); + const auto* constant_node = AsIgnoringOnDevice(shape); + ICHECK(constant_node); + attrs->const_shape = GetRef(constant_node); } static const Op& op = Op::Get("memory.alloc_tensor"); return Call(op, {storage, offset, shape}, Attrs(attrs), {}); diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index ed9409856871..fd7f0a5594c2 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -118,9 +118,8 @@ inline Expr TransformF(const std::function& func, const Expr& * is it atomic? * if so, the compute cost of the expression is bounded so it can be copy without graph mode. */ -inline bool IsAtomic(const Expr& e) { - auto props = GetOnDeviceProps(e); - Expr true_expr = props.body.defined() ? props.body : e; +inline bool IsAtomic(const Expr& expr) { + Expr true_expr = IgnoreOnDevice(expr); return true_expr.as() || true_expr.as() || true_expr.as() || true_expr.as() || true_expr.as(); // Constant is always by reference. diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 42fe1a3cef3a..8ec41523f9dc 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -766,6 +766,19 @@ def test_vm_reshape_tensor(target, dev): check_result(target, dev, [x_np, y_np], x_np.reshape([8, 2, 8]), mod) +def test_vm_reshape_and_copy(target, dev): + """Make sure the compiler notices the reshape result shape is a literal and can use + the immediate-mode alloc_tensor instruction instead of alloc_tensor_reg.""" + x_np = np.random.uniform(size=(1, 1)).astype("float32") + x = relay.var("x", shape=(1, 1), dtype="float32") + mod = tvm.IRModule.from_expr(relay.Function([x], relay.copy(relay.reshape(x, [0, 1])))) + with tvm.transform.PassContext(opt_level=3): + exec = relay.vm.compile(mod, "llvm") + assert "alloc_tensor" in exec.bytecode + assert not "alloc_tensor_reg" in exec.bytecode + check_result(target, dev, [x_np], x_np.reshape([1, 1]), mod) + + def test_vm_reshape_tuple(target, dev, x_shape=(1, 4, 2), y_shape=(1, 2, 10)): tup = relay.var( "tup", @@ -963,4 +976,4 @@ def test_benchmark_end_to_end_rpc(): if __name__ == "__main__": import sys - sys.exit(pytest.main(sys.argv)) + sys.exit(pytest.main([__file__] + sys.argv[1:])) From d5dd8c019e342e849a9bc716d53bcae3fdbe9f9d Mon Sep 17 00:00:00 2001 From: masahi Date: Sat, 23 Oct 2021 00:59:00 +0900 Subject: [PATCH 70/84] Disable Hexagon TestConv2dPackedFilter test (#9344) --- tests/python/contrib/test_hexagon/test_conv2d_blocked.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py index 1304d341eda2..07696b51a327 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py @@ -546,6 +546,7 @@ class TestConv2dPackedFilter(BaseConv2d): conv2d_impl = tvm.testing.parameter(conv2d_packed_filter, conv2d_packed_filter_nhwhwc) @tvm.testing.parametrize_targets("llvm") + @pytest.mark.skip("Skip due to being flaky on i386.") def test_conv2d( self, conv2d_impl, From 982e8e525bad711514262f30a1b1792568dd6f0e Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 22 Oct 2021 13:33:06 -0500 Subject: [PATCH 71/84] [Hexagon] Fix cmake files for Hexagon launcher (#9343) * [Hexagon] Fix cmake files for Hexagon launcher The update to build the launcher automatically accidentally broke building it separately. This patch fixes that. Also included are a few minor fixes and an update to the README.md. * Clarify support for Hexagon codegen --- apps/hexagon_launcher/README.md | 89 ++++++++++--------- .../cmake/HexagonLauncher.cmake | 5 -- .../cmake/android/CMakeLists.txt | 21 ++--- .../cmake/hexagon/CMakeLists.txt | 21 +++-- cmake/modules/Hexagon.cmake | 17 ++-- 5 files changed, 77 insertions(+), 76 deletions(-) diff --git a/apps/hexagon_launcher/README.md b/apps/hexagon_launcher/README.md index 85e6897b74a3..b190dd81a7b2 100644 --- a/apps/hexagon_launcher/README.md +++ b/apps/hexagon_launcher/README.md @@ -40,29 +40,33 @@ tvm_runtime, as well as the Hexagon launcher shared library and its correspondin tvm_runtime. As described in the [Manual compilation](#Manual compilation) section each component requires Hexagon and android dependencies. When building the launcher along with TVM these configurations must be providing when invoking cmake. A minimal -example invocation for compiling TVM along with the Hexagon launcher is included below, +example invocation for compiling TVM along with the Hexagon launcher is included below: ``` -cmake -DCMAKE_MAKE_PROGRAM=make \ - -DCMAKE_C_COMPILER=clang \ - -DCMAKE_CXX_COMPILER=clang++ \ +cmake -DCMAKE_C_COMPILER=/path/to/clang \ + -DCMAKE_CXX_COMPILER=/path/to/clang++ \ -DCMAKE_CXX_FLAGS='-stdlib=libc++' \ -DCMAKE_CXX_STANDARD=14 \ - -DUSE_LLVM=/path/to/hexagon/llvm/bin/llvm-config \ + -DUSE_LLVM=/path/to/llvm/bin/llvm-config \ + -DUSE_HEXAGON_ARCH=v65|v66|v68 \ -DUSE_HEXAGON_LAUNCHER=ON \ - -DUSE_ANDROID_TOOLCHAIN=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ - -DANDROID_PLATFORM=android-28 \ - -DANDROID_ABI=arm64-v8a \ - -DUSE_HEXAGON_ARCH=v68 \ -DUSE_HEXAGON_SDK=/path/to/hexagon/SDK \ - -DUSE_HEXAGON_TOOLCHAIN=/path/to/hexagon/Toolchain/ .. + -DUSE_HEXAGON_TOOLCHAIN=/path/to/hexagon/toolchain/ .. + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DUSE_ANDROID_TOOLCHAIN=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ + .. ``` +where `v65|v66|v68` means "one of" these architecture versions. The Hexagon launcher application is an android binary and thus requires the use of an android toolchain for compilation. Similarly, the Hexagon tvm runtime requires the use of the Hexagon toolchain and depends on the Hexagon SDK. The -resulting hexagon launcher binaries can be found in the `launcher` subdirectory -of the cmake build directory. +resulting hexagon launcher binaries can be found in the `apps_hexagon_launcher` +subdirectory of the cmake build directory. Please note that the above command +will not build support for Hexagon codegen in the TVM library, for that please +additionally define the `USE_HEXAGON_DEVICE` variable. Also, the LLVM used in +`USE_LLVM` should have Hexagon target built in. ### Manual compilation @@ -72,43 +76,44 @@ code first. #### Compilation of the Hexagon part -1. Build the static version of TVM runtime for Hexagon. Use Hexagon clang - from the Hexagon SDK. This step is the same as building the shared version, - except at the cmake step, add `-DBUILD_STATIC_RUNTIME=ON`. The compilation - step should create `libtvm_runtime.a`. - -2. Create a subdirectory for the build files, and run `cmake` with the - following variables set: - - `FASTRPC_LIBS=SKEL` - - `USE_HEXAGON_SDK` to the path to the Hexagon SDK - - `CMAKE_C_COMPILER=hexagon-clang` - - `CMAKE_CXX_COMPILER=hexagon-clang++` - - `USE_HEXAGON_ARCH` to one of v65, v66, v68 - - `TVM_RUNTIME_HEXAGON=/path/to/libtvm_runtime.a` _statically_ linked - TVM runtime +Create a subdirectory for the build files, and run `cmake` with the +following variables set: - Make sure to provide the path to launcher's `CMakeLists.txt` directory - in `cmake` invocation. +``` +cmake -DCMAKE_C_COMPILER=/path/to/hexagon-clang \ + -DCMAKE_CXX_COMPILER=/path/to/hexagon-clang++ \ + -DUSE_HEXAGON_ARCH=v65|v66|v68 \ + -DUSE_HEXAGON_SDK=/path/to/hexagon/SDK \ + /path/to/apps/hexagon_launcher/cmake/hexagon +``` -3. Run `make`. This will create `liblauncher_rpc_skel.so`. +Run `make`. This will create `liblauncher_rpc_skel.so`. The static version of +the TVM runtime for Hexagon will be built as a part of the process. #### Compilation of the Android part -1. Build TVM runtime for Android, using clang for AArch64 from the Android - NDK. Unlike in the Hexagon case, this should be the dynamic library (which - is the default), i.e. `libtvm_runtime.so`. - 2. Create a subdirectory for the build files (different from the one used for Hexagon files), and run `cmake` with the following variables set: - - `FASTRPC_LIBS=STUB` - - `USE_HEXAGON_SDK` to the path to the Hexagon SDK - - `CMAKE_C_COMPILER=aarch64-linux-android28-clang` (or later) - - `CMAKE_CXX_COMPILER=aarch64-linux-android28-clang++` (or later) - - `USE_HEXAGON_ARCH` to one of v65, v66, v68 (same as for the Hexagon part) - - `TVM_RUNTIME_ANDROID=/path/to/libtvm_runtime.so` dynamically or - statically linked TVM runtime - -3. Run `make`. This will create `launcher_android`. + +``` +cmake -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DUSE_HEXAGON_SDK=/p/Hexagon_SDK/4.3.0.0 + -DUSE_HEXAGON_ARCH=v65|v66|v68 + /path/to/apps/hexagon_launcher/cmake/android +``` + +Run `make`. This will create `launcher_android`. The TVM runtime for Android will +be built as a part of the process. Depending on the version of cmake that you are +using, you may see the following warnings---they can be ignored. + +``` +An old version of CMake is being used that cannot automatically detect +compiler attributes. Compiler identification is being bypassed. Some +values may be wrong or missing. Update to CMake 3.19 or newer to use +CMake's built-in compiler identification. +``` ## Execution diff --git a/apps/hexagon_launcher/cmake/HexagonLauncher.cmake b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake index 4a7f803ce1ab..abf877cb67f1 100644 --- a/apps/hexagon_launcher/cmake/HexagonLauncher.cmake +++ b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake @@ -15,11 +15,6 @@ # specific language governing permissions and limitations # under the License. -if(NOT "${FASTRPC_LIBS}" STREQUAL "SKEL" AND - NOT "${FASTRPC_LIBS}" STREQUAL "STUB") - message(SEND_ERROR "Please set FASTRPC_LIBS to either SKEL or STUB") -endif() - if(NOT DEFINED USE_HEXAGON_SDK) message(SEND_ERROR "Please set USE_HEXAGON_SDK to the location of Hexagon SDK") endif() diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt index c000b0e97cad..7716cde99863 100644 --- a/apps/hexagon_launcher/cmake/android/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/android/CMakeLists.txt @@ -21,17 +21,15 @@ project(HexagonAndroidLauncher C CXX) include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") add_custom_command( - OUTPUT ${LAUNCHER_RPC_STUB_C} - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} - "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" - COMMAND ${CMAKE_COMMAND} -E rename "${LAUNCHER_RPC_H}" - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" + OUTPUT ${LAUNCHER_RPC_STUB_C} ${LAUNCHER_RPC_H} + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" ) include_directories(SYSTEM "${HEXAGON_SDK_INCLUDES}" "${HEXAGON_RPCMEM_ROOT}/inc" + "${CMAKE_CURRENT_BINARY_DIR}" # Output of qaic will go here ) link_directories(${HEXAGON_REMOTE_ROOT}) @@ -46,8 +44,9 @@ set(STUB_SRCS ) add_executable(launcher_android - "${STUB_SRCS}" + "${LAUNCHER_RPC_H}" "${LAUNCHER_RPC_STUB_C}" + "${STUB_SRCS}" ) ExternalProject_Add(android_tvm_runtime @@ -66,12 +65,14 @@ ExternalProject_Add(android_tvm_runtime ) ExternalProject_Get_Property(android_tvm_runtime BINARY_DIR) ExternalProject_Add_Step(android_tvm_runtime copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${CMAKE_INSTALL_PREFIX} + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/libtvm_runtime.so + ${CMAKE_CURRENT_BINARY_DIR} DEPENDEES install ) add_dependencies(launcher_android android_tvm_runtime) -add_library(tvm_runtime SHARED IMPORTED) -set_target_properties(tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.so") +add_library(a_tvm_runtime SHARED IMPORTED) +set_target_properties(a_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.so") -target_link_libraries(launcher_android cdsprpc log tvm_runtime) +target_link_libraries(launcher_android cdsprpc log a_tvm_runtime) diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt index c76fcccc5a1a..3f99459f3a49 100644 --- a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt @@ -22,12 +22,14 @@ include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") add_custom_command( OUTPUT ${LAUNCHER_RPC_SKEL_C} ${LAUNCHER_RPC_H} - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} - "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" ) -include_directories(SYSTEM ${HEXAGON_QURT_INCLUDES}) +include_directories(SYSTEM + ${HEXAGON_QURT_INCLUDES} + ${CMAKE_CURRENT_BINARY_DIR} # Output of qaic will go here +) link_directories(${HEXAGON_QURT_LIBS}) @@ -48,8 +50,9 @@ set(SKEL_SRCS "${LAUNCHER_SRC}/launcher_core.cc" "${LAUNCHER_SRC}/launcher_hexagon.cc" ) + add_library(launcher_rpc_skel SHARED - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" + "${LAUNCHER_RPC_H}" "${LAUNCHER_RPC_SKEL_C}" "${SKEL_SRCS}" ) @@ -71,14 +74,10 @@ ExternalProject_Add(static_hexagon_tvm_runtime BUILD_ALWAYS ON ) ExternalProject_Get_Property(static_hexagon_tvm_runtime BINARY_DIR) -ExternalProject_Add_Step(static_hexagon_tvm_runtime copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${CMAKE_INSTALL_PREFIX} - DEPENDEES install -) add_dependencies(launcher_rpc_skel static_hexagon_tvm_runtime) -add_library(static_tvm_runtime STATIC IMPORTED) -set_target_properties(static_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.a") +add_library(h_tvm_runtime STATIC IMPORTED) +set_target_properties(h_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.a") -target_link_libraries(launcher_rpc_skel -Wl,--whole-archive static_tvm_runtime -Wl,--no-whole-archive) +target_link_libraries(launcher_rpc_skel -Wl,--whole-archive h_tvm_runtime -Wl,--no-whole-archive) diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index 1491a4558611..88623ab045fd 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -76,7 +76,6 @@ if(NOT USE_HEXAGON_SDK) endif() if(USE_HEXAGON_LAUNCHER STREQUAL "ON") - if(DEFINED USE_ANDROID_TOOLCHAIN) if(NOT DEFINED ANDROID_PLATFORM) message(SEND_ERROR "Please set ANDROID_PLATFORM " @@ -91,7 +90,7 @@ if(USE_HEXAGON_LAUNCHER STREQUAL "ON") " launcher for hexagon.") endif() - set(LAUNCHER_BINARY_DIR "${CMAKE_BINARY_DIR}/launcher") + set(LAUNCHER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/apps_hexagon_launcher") ExternalProject_Add(launcher_android SOURCE_DIR "${CMAKE_SOURCE_DIR}/apps/hexagon_launcher/cmake/android" INSTALL_DIR "${LAUNCHER_BINARY_DIR}" @@ -101,14 +100,15 @@ if(USE_HEXAGON_LAUNCHER STREQUAL "ON") "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" "-DANDROID_ABI=${ANDROID_ABI}" "-DFASTRPC_LIBS=STUB" - "-DUSE_HEXAGON_ARCH=v68" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DCMAKE_INSTALL_PREFIX:PATH=" INSTALL_COMMAND "" ) ExternalProject_Get_Property(launcher_android BINARY_DIR) ExternalProject_Add_Step(launcher_android copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${LAUNCHER_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/launcher_android ${BINARY_DIR}/libtvm_runtime.so + ${LAUNCHER_BINARY_DIR} DEPENDEES install ) ExternalProject_Add(launcher_hexagon @@ -119,14 +119,15 @@ if(USE_HEXAGON_LAUNCHER STREQUAL "ON") "-DCMAKE_C_COMPILER=${USE_HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang" "-DCMAKE_CXX_COMPILER=${USE_HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++" "-DFASTRPC_LIBS=SKEL" - "-DUSE_HEXAGON_ARCH=v68" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DCMAKE_INSTALL_PREFIX:PATH=" INSTALL_COMMAND "" ) ExternalProject_Get_Property(launcher_hexagon BINARY_DIR) ExternalProject_Add_Step(launcher_hexagon copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${LAUNCHER_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/liblauncher_rpc_skel.so + ${LAUNCHER_BINARY_DIR} DEPENDEES install ) From d34a632db56f65b8298215faed52dadba9d189f1 Mon Sep 17 00:00:00 2001 From: masahi Date: Sat, 23 Oct 2021 05:37:10 +0900 Subject: [PATCH 72/84] Support dynamic shape searchsorted (#9348) --- python/tvm/relay/op/_algorithm.py | 25 +++++++++++++++++++++++++ tests/python/relay/test_any.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 19162a108395..dd1a65288955 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -84,3 +84,28 @@ def topk_shape_func(attrs, inputs, _): ret = [indices_out] return ret + + +@script +def _searchsorted_shape(sorted_sequence_shape, values_shape): + out_shape = output_tensor((values_shape.shape[0],), "int64") + if sorted_sequence_shape.shape[0] > 1: + assert ( + sorted_sequence_shape.shape[0] == values_shape.shape[0] + ), "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is not 1-D." + for i in range(values_shape.shape[0]): + if sorted_sequence_shape.shape[0] > 1 and i < values_shape.shape[0] - 1: + assert ( + sorted_sequence_shape[i] == values_shape[i] + ), "`sorted_sequence and `values` do not have the same shape along outer axes." + + out_shape[i] = values_shape[i] + return out_shape + + +@_reg.register_shape_func("searchsorted", False) +def searchsorted_shape_func(attrs, inputs, _): + """ + Shape func for searchsorted operator. + """ + return [_searchsorted_shape(inputs[0], inputs[1])] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8788faf45866..f42f7ad7ca69 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -23,6 +23,7 @@ from tvm import relay, te from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type +from tvm.topi.testing import searchsorted_ref from utils import ref_funcs from utils.assert_diagnostic import DiagnosticTesting @@ -2086,5 +2087,35 @@ def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, ax verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0) +@tvm.testing.uses_gpu +def test_searchsorted(): + def verify_searchsorted( + sorted_sequence_shape, values_shape, sorted_sequence_shape_np, values_shape_np + ): + x = relay.var("x", relay.TensorType(sorted_sequence_shape, "float32")) + y = relay.var("y", relay.TensorType(values_shape, "float32")) + z = relay.searchsorted(x, y) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + x_np = np.sort(np.random.uniform(size=sorted_sequence_shape_np).astype("float32"), axis=-1) + y_np = np.random.uniform(size=values_shape_np).astype("float32") + + ref_res = searchsorted_ref(x_np, y_np, False, "int32") + check_result([x_np, y_np], mod, [ref_res]) + + for shape_np, values_shape_np in zip([(8, 9, 10), (10,), (11,)], [(8, 9, 20), (5,), (8, 9, 7)]): + sorted_sequence_shape = (relay.Any(),) * len(shape_np) + values_shape = (relay.Any(),) * len(values_shape_np) + + verify_searchsorted( + sorted_sequence_shape, + values_shape, + shape_np, + values_shape_np, + ) + + if __name__ == "__main__": pytest.main([__file__]) From f0efecc8e19ecdc0a322bb8fc234e7b527f8b654 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Fri, 22 Oct 2021 13:44:19 -0700 Subject: [PATCH 73/84] [microTVM][Zephyr] Enable RISCV Tests on QEMU CI (#9325) * add riscv32 * add riscv64 * fix url --- apps/microtvm/zephyr/template_project/boards.json | 2 +- .../zephyr/template_project/crt_config/crt_config.h | 2 +- .../zephyr/template_project/src/host_driven/main.c | 5 ----- tests/micro/zephyr/test_zephyr.py | 3 +++ tests/micro/zephyr/test_zephyr_aot.py | 6 ++++-- tests/scripts/task_python_microtvm.sh | 9 ++++++++- 6 files changed, 17 insertions(+), 10 deletions(-) diff --git a/apps/microtvm/zephyr/template_project/boards.json b/apps/microtvm/zephyr/template_project/boards.json index aabed3322150..18e393897f04 100644 --- a/apps/microtvm/zephyr/template_project/boards.json +++ b/apps/microtvm/zephyr/template_project/boards.json @@ -39,7 +39,7 @@ "board": "qemu_riscv32", "model": "host", "is_qemu": true, - "fpu": true + "fpu": false }, "qemu_riscv64": { "board": "qemu_riscv64", diff --git a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h index f8fc7514a28d..39fe27ef3d05 100644 --- a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h +++ b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h @@ -36,7 +36,7 @@ #define TVM_CRT_MAX_ARGS 10 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/microtvm/zephyr/template_project/src/host_driven/main.c b/apps/microtvm/zephyr/template_project/src/host_driven/main.c index 43064e804193..44d656028cbc 100644 --- a/apps/microtvm/zephyr/template_project/src/host_driven/main.c +++ b/apps/microtvm/zephyr/template_project/src/host_driven/main.c @@ -260,11 +260,6 @@ void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { // The main function of this application. extern void __stdout_hook_install(int (*hook)(int)); void main(void) { - // TODO (mehrdadh): Update this when zephyr version has updated to 2.6. - // Update zephyr to latest version to use with qemu_riscv32. -#ifdef CONFIG_BOARD_QEMU_RISCV32 - k_float_enable(_current, 0); -#endif #ifdef CONFIG_LED int ret; diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index be1f231156ad..089598007651 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -374,6 +374,9 @@ def test_tensors(sess): @tvm.testing.requires_micro def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): """Test AutoTune for microTVM Zephyr""" + if board in ["qemu_riscv32", "qemu_riscv64"]: + pytest.xfail(f"Autotune fails on {board}.") + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index 5bc665b748f6..f79aa8bd70d2 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -47,6 +47,8 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): "nrf5340dk_nrf5340_cpuapp", "nucleo_l4r5zi", "qemu_cortex_r5", + "qemu_riscv32", + "qemu_riscv64", ]: pytest.skip(msg="Model does not fit.") @@ -55,8 +57,8 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): output_shape = (1, 10) build_config = {"debug": tvm_debug} - model_url = "https://github.com/eembc/ulpmark-ml/raw/fc1499c7cc83681a02820d5ddf5d97fe75d4f663/base_models/ic01/ic01_fp32.tflite" - model_path = download_testdata(model_url, "ic01_fp32.tflite", module="model") + model_url = "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/model/image_classification_fp32.tflite" + model_path = download_testdata(model_url, "image_classification_fp32.tflite", module="model") # Import TFLite model tflite_model_buf = open(model_path, "rb").read() diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh index 6632ebb1ca52..8de8b908ee09 100755 --- a/tests/scripts/task_python_microtvm.sh +++ b/tests/scripts/task_python_microtvm.sh @@ -23,13 +23,20 @@ set -x # NOTE(areusch): Adding to diagnose flaky timeouts source tests/scripts/setup-pytest-env.sh make cython3 -run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --zephyr-board=qemu_x86 + +# Zephyr +run_pytest ctypes python-microtvm-zephyr-qemu_x86 tests/micro/zephyr --zephyr-board=qemu_x86 +run_pytest ctypes python-microtvm-zephyr-qemu_riscv32 tests/micro/zephyr --zephyr-board=qemu_riscv32 +run_pytest ctypes python-microtvm-zephyr-qemu_riscv64 tests/micro/zephyr --zephyr-board=qemu_riscv64 + # Temporarily removing mps2_an512 from CI due to issue 8728: # https://github.com/apache/tvm/issues/8728 # run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --zephyr-board=mps2_an521 +# Arduino run_pytest ctypes python-microtvm-arduino apps/microtvm/arduino/template_project/tests run_pytest ctypes python-microtvm-arduino-nano33ble tests/micro/arduino --test-build-only --arduino-board=nano33ble run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build-only --arduino-board=due +# STM32 run_pytest ctypes python-microtvm-stm32 tests/micro/stm32 From 6ef1c2a4b028dd170a0afe6ebceb45a5325033d3 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Sat, 23 Oct 2021 00:16:05 +0300 Subject: [PATCH 74/84] [CORE][Relay] Swap and remove compile_engine with te_compiler followup of #8775 (#9282) * Remove compile_engine.h for real * Fix format * RM compile_engine.cc * Swap compile engine with TECompiler * Cleanup on compile engine py leftovers * [WIP] Exposing legacy compile engine capabilities through TE Compiler * Swap usages for depreciated compile engine with TE compiler * Track and replace usages of compile engine refactor them to TE compiler * [Docs] Log helper mod * Remove depreciated function for lookup compile engine cachce * Fix typos * Debug misc cleanups * Register global pass for using te compiler for auto scheduler * Fix tests using the legacy compile engine * Fix broken autotuner tests and minor cleanups * Swap compile engine with te_compiler in rst config * PR nits * Fix failed test Co-authored-by: Jared Roesch --- docs/arch/relay_op_strategy.rst | 8 +- docs/reference/api/python/relay/backend.rst | 2 +- .../tvm/auto_scheduler/relay_integration.py | 4 +- .../graph_tuner/utils/traverse_graph.py | 2 +- python/tvm/autotvm/task/relay_integration.py | 4 +- python/tvm/relay/backend/__init__.py | 2 +- .../{compile_engine.py => te_compiler.py} | 130 ++----- python/tvm/relay/testing/py_converter.py | 13 +- python/tvm/topi/arm_cpu/conv2d_alter_op.py | 2 +- python/tvm/topi/bifrost/conv2d.py | 2 +- python/tvm/topi/cuda/conv2d_alter_op.py | 2 +- python/tvm/topi/cuda/conv3d_alter_op.py | 2 +- .../topi/intel_graphics/conv2d_alter_op.py | 2 +- python/tvm/topi/mali/conv2d.py | 2 +- python/tvm/topi/x86/conv2d_alter_op.py | 2 +- python/tvm/topi/x86/dense_alter_op.py | 2 +- src/relay/backend/build_module.cc | 4 +- src/relay/backend/compile_engine.cc | 338 ------------------ src/relay/backend/compile_engine.h | 115 ------ src/relay/backend/interpreter.cc | 2 +- src/relay/backend/te_compiler.cc | 39 ++ src/relay/backend/te_compiler.h | 3 +- src/relay/backend/te_compiler_cache.h | 1 - src/relay/backend/utils.h | 9 - .../auto_scheduler_layout_rewrite.cc | 5 +- src/runtime/object.cc | 4 +- .../test_arm_compute_lib/infrastructure.py | 2 +- .../contrib/test_bnns/infrastructure.py | 2 +- .../contrib/test_ethosn/infrastructure.py | 4 +- .../contrib/test_vitis_ai/infrastructure.py | 2 +- tests/python/relay/aot/aot_test_utils.py | 5 +- .../relay/dyn/test_dynamic_op_level3.py | 5 +- tests/python/relay/test_json_runtime.py | 8 +- tests/python/relay/test_op_level3.py | 8 +- .../python/relay/test_pass_partition_graph.py | 10 +- ...le_engine.py => test_relay_te_compiler.py} | 26 +- .../test_tir_transform_narrow_datatype.py | 7 +- 37 files changed, 151 insertions(+), 629 deletions(-) rename python/tvm/relay/backend/{compile_engine.py => te_compiler.py} (79%) delete mode 100644 src/relay/backend/compile_engine.cc delete mode 100644 src/relay/backend/compile_engine.h rename tests/python/relay/{test_backend_compile_engine.py => test_relay_te_compiler.py} (93%) diff --git a/docs/arch/relay_op_strategy.rst b/docs/arch/relay_op_strategy.rst index c40251d22433..dbac7c821827 100644 --- a/docs/arch/relay_op_strategy.rst +++ b/docs/arch/relay_op_strategy.rst @@ -269,14 +269,14 @@ will then be chosen. Implementations with same priority level in this case leads to an undefined behavior, and any of them might be selected. The selection policy for ops with symbolic input shapes is still work in -progess. Currently, if any input tensor has a symbolic shape, only the +progress. Currently, if any input tensor has a symbolic shape, only the implementation with highest priority level will be used for this operator. This -will be updated after the implemention finishes. +will be updated after the implementation finishes. For debug purpose, you can add the following lines before you compile the Relay model to learn which implementation is used for each operator. .. code:: python - logging.getLogger("compile_engine").setLevel(logging.INFO) - logging.getLogger("compile_engine").addHandler(logging.StreamHandler(sys.stdout)) + logging.getLogger("te_compiler").setLevel(logging.INFO) + logging.getLogger("te_compiler").addHandler(logging.StreamHandler(sys.stdout)) diff --git a/docs/reference/api/python/relay/backend.rst b/docs/reference/api/python/relay/backend.rst index ffe8a9a8ce79..e717ee10ffab 100644 --- a/docs/reference/api/python/relay/backend.rst +++ b/docs/reference/api/python/relay/backend.rst @@ -23,7 +23,7 @@ tvm.relay.backend .. automodule:: tvm.relay.backend.interpreter :members: -.. automodule:: tvm.relay.backend.compile_engine +.. automodule:: tvm.relay.backend.te_compiler :members: .. automodule:: tvm.relay.backend.graph_executor_codegen diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0eacd1a1f667..6f35e021daf8 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -58,7 +58,6 @@ def call_all_topi_funcs(mod, params, target, opt_level=3): opt_level=opt_level, config={ "relay.backend.use_auto_scheduler": True, - "relay.backend.disable_compile_engine_cache": True, }, disabled_pass={"AutoSchedulerLayoutRewrite"}, ): @@ -165,7 +164,8 @@ class TracingMode: """Two modes for tracing""" EXTRACT_TASK = 0 # trace all topi calls to extract tasks - EXTRACT_COMPLEX_TASK_ONLY = 1 # same as EXTRACT_TASK but ignore the task without complex ops + # same as EXTRACT_TASK but ignore the task without complex ops + EXTRACT_COMPLEX_TASK_ONLY = 1 PREPARE_LAYOUT_REWRITE = 2 # trace topi calls to prepare layout rewrite diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 723e7fa77006..7299875bf28d 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -142,7 +142,7 @@ def _traverse_expr(node): params.append(free_var) call = relay.Call(node.op, params, node.attrs) mod = tvm.IRModule.from_expr(relay.Function(params, call)) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() tracing_target = _replace_device_with_tracing(tvm_target) build_thread = threading.Thread( target=relay.build, args=(mod, tracing_target, None, None) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 714dd540d3ab..4716116a1b83 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -127,12 +127,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No assert isinstance( mod, tvm.IRModule ), "only support relay Module or Function to be tuned" - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() # wrap build call in thread to avoid multiprocessing problems build_thread = threading.Thread(target=_lower, args=(mod, target, param)) build_thread.start() build_thread.join() - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() # Clear the warning message cache in FallbackContext if isinstance(DispatchContext.current, FallbackContext): DispatchContext.current.memory = {} diff --git a/python/tvm/relay/backend/__init__.py b/python/tvm/relay/backend/__init__.py index 4fc2b63748db..d76459236515 100644 --- a/python/tvm/relay/backend/__init__.py +++ b/python/tvm/relay/backend/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """Backend codegen modules for relay.""" -from . import compile_engine +from . import te_compiler diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/te_compiler.py similarity index 79% rename from python/tvm/relay/backend/compile_engine.py rename to python/tvm/relay/backend/te_compiler.py index e9129db7b200..db7504915887 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=len-as-condition,no-else-return,invalid-name -"""Backend code generation engine.""" +"""TE compiler engine (replacing legacy compile_engine).""" from __future__ import absolute_import import logging -import numpy as np import tvm from tvm import te, autotvm from tvm.ir.transform import PassContext @@ -31,7 +30,7 @@ from .. import ty as _ty from . import _backend -logger = logging.getLogger("compile_engine") +logger = logging.getLogger("te_compiler") autotvm_logger = logging.getLogger("autotvm") _first_warning = True @@ -47,7 +46,7 @@ def __init__(self, outputs, implement): @tvm._ffi.register_object("relay.CCacheKey") class CCacheKey(Object): - """Key in the CompileEngine. + """Key in the TE Compiler. Parameters ---------- @@ -64,7 +63,7 @@ def __init__(self, source_func, target): @tvm._ffi.register_object("relay.CCacheValue") class CCacheValue(Object): - """Value in the CompileEngine, including usage statistics.""" + """Value in the TE Compiler, including usage statistics.""" def _get_cache_key(source_func, target): @@ -79,24 +78,6 @@ def _get_cache_key(source_func, target): return source_func -def get_shape(shape): - """Convert the shape to correct dtype and vars.""" - ret = [] - for dim in shape: - if isinstance(dim, tvm.tir.IntImm): - if libinfo()["INDEX_DEFAULT_I64"] == "ON": - ret.append(dim) - else: - val = int(dim) - assert val <= np.iinfo(np.int32).max - ret.append(tvm.tir.IntImm("int32", val)) - elif isinstance(dim, tvm.tir.Any): - ret.append(te.var("any_dim", "int32")) - else: - ret.append(dim) - return ret - - def get_valid_implementations(op, attrs, inputs, out_type, target): """Get all valid implementations from the op strategy. @@ -275,6 +256,24 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) return best_plevel_impl, outputs[best_plevel_impl] +def get_shape(shape): + """Convert the shape to correct dtype and vars.""" + ret = [] + for dim in shape: + if isinstance(dim, tvm.tir.IntImm): + if libinfo()["INDEX_DEFAULT_I64"] == "ON": + ret.append(dim) + else: + val = int(dim) + assert val <= np.iinfo(np.int32).max + ret.append(tvm.tir.IntImm("int32", val)) + elif isinstance(dim, tvm.tir.Any): + ret.append(te.var("any_dim", "int32")) + else: + ret.append(dim) + return ret + + @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" @@ -322,12 +321,12 @@ def lower_call(call, inputs, target): return LoweredOutput(outputs, best_impl) -@tvm._ffi.register_object("relay.CompileEngine") -class CompileEngine(Object): - """CompileEngine to get lowered code.""" +@tvm._ffi.register_object("relay.TECompiler") +class TECompiler(Object): + """TECompiler to get lowered code.""" def __init__(self): - raise RuntimeError("Cannot construct a CompileEngine") + raise RuntimeError("Cannot construct a TECompiler") def lower(self, source_func, target=None, mod_name="default"): """Lower a source_func to a CachedFunc. @@ -349,7 +348,7 @@ def lower(self, source_func, target=None, mod_name="default"): try: mod_name = mangle_module_name(mod_name) key = _get_cache_key(source_func, target) - return _backend._CompileEngineLower(self, key, mod_name) + return _backend._TECompilerLower(self, key, mod_name) except Exception: import traceback @@ -360,10 +359,6 @@ def lower(self, source_func, target=None, mod_name="default"): msg += "--------------------------\n" raise RuntimeError(msg) - def lower_shape_func(self, source_func, target=None): - key = _get_cache_key(source_func, target) - return _backend._CompileEngineLowerShapeFunc(self, key) - def jit(self, source_func, target=None): """JIT a source_func to a tvm.runtime.PackedFunc. @@ -381,87 +376,30 @@ def jit(self, source_func, target=None): The result of jited function. """ key = _get_cache_key(source_func, target) - return _backend._CompileEngineJIT(self, key) + return _backend._TECompilerJIT(self, key) def clear(self): """clear the existing cached functions""" - _backend._CompileEngineClear(self) + _backend._TECompilerClear(self) def items(self): """List items in the cache. - Returns ------- item_list : List[Tuple[CCacheKey, CCacheValue]] The list of items. """ - res = _backend._CompileEngineListItems(self) - assert len(res) % 2 == 0 - return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] - - def shape_func_items(self): - """List items in the shape_func_cache. - - Returns - ------- - item_list : List[Tuple[CCacheKey, CCacheValue]] - The list of shape_func_items. - """ - res = _backend._CompileEngineListShapeFuncItems(self) + res = _backend._TECompilerListItems(self) assert len(res) % 2 == 0 return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] - def get_current_ccache_key(self): - return _backend._CompileEngineGetCurrentCCacheKey(self) - - def dump(self): - """Return a string representation of engine dump. - - Returns - ------- - dump : str - The dumped string representation - """ - items = self.items() - res = "====================================\n" - res += "CompilerEngine dump, %d items cached\n" % len(items) - for k, v in items: - res += "------------------------------------\n" - res += "target={}\n".format(k.target) - res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) - res += "----relay function----\n" - res += k.source_func.astext() + "\n" - res += "----tir function----- \n" - res += "inputs={}\n".format(v.cached_func.inputs) - res += "outputs={}\n".format(v.cached_func.outputs) - res += "function: \n" - res += v.cached_func.funcs.astext() + "\n" - res += "===================================\n" - shape_func_items = self.shape_func_items() - res += "%d shape_func_items cached\n" % len(shape_func_items) - for k, v in shape_func_items: - res += "------------------------------------\n" - res += "target={}\n".format(k.target) - res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) - res += "----relay function----\n" - res += k.source_func.astext() + "\n" - res += "----tir function----- \n" - res += "inputs={}\n".format(v.cached_func.inputs) - res += "outputs={}\n".format(v.cached_func.outputs) - res += "function: \n" - res += v.cached_func.funcs.astext() + "\n" - res += "===================================\n" - return res - def get(): - """Get the global compile engine. + """Get the global TE Compiler. Returns ------- - engine : tvm.relay.backend.CompileEngine - The compile engine. + engine : tvm.relay.backend.TECompiler + The TE Compiler. """ - return _backend._CompileEngineGlobal() + return _backend._TECompilerGlobal() diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index b9d6806306f4..50f473aea1f2 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -24,7 +24,7 @@ import tvm from tvm import relay from tvm.relay.adt import Pattern -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler from tvm.relay.expr import Expr, GlobalVar, Var from tvm.relay.function import Function from tvm.relay.expr_functor import ExprFunctor @@ -61,7 +61,7 @@ def __init__(self, mod, target) -> None: super().__init__() self.mod = mod self.tgt = target - self.engine = compile_engine.get() + self.tec = te_compiler.get() self.fun_no = 0 self.var_no = 0 self.var_map = {} @@ -153,7 +153,10 @@ def parse_name(self, name: str): def parse_numpy_array(self, arr): """Given a Numpy array, produces an appropriate Python array or numerical literal representing its contents.""" - parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i) + + def parse_single(i): + return NameConstant(i) if isinstance(i, bool) else Num(i) + if arr.ndim == 0: return parse_single(arr.item()) if arr.ndim == 1: @@ -240,11 +243,11 @@ def create_op_call(self, op: Function, relay_args, py_args): the generated Python code.""" # compile the function and register globally - cc_key = compile_engine.CCacheKey(op, self.tgt) + cc_key = te_compiler.CCacheKey(op, self.tgt) func_hash = tvm.ir.structural_hash(op) op_name = "_lowered_op_{}".format(func_hash) if not tvm.get_global_func(op_name, allow_missing=True): - jitted = self.engine.jit(cc_key, self.tgt) + jitted = self.tec.jit(cc_key, self.tgt) tvm.register_func(op_name, jitted) def convert_input(py_input, arg_type): diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index c7c572c81110..cbe8644c885f 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -90,7 +90,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/bifrost/conv2d.py b/python/tvm/topi/bifrost/conv2d.py index 3b6cca6aaea4..633f36c0e7ff 100644 --- a/python/tvm/topi/bifrost/conv2d.py +++ b/python/tvm/topi/bifrost/conv2d.py @@ -477,7 +477,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 4863a06b728d..3d05058ff52c 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -46,7 +46,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/conv3d_alter_op.py b/python/tvm/topi/cuda/conv3d_alter_op.py index faf73e77255a..c7ec7cb21fcf 100644 --- a/python/tvm/topi/cuda/conv3d_alter_op.py +++ b/python/tvm/topi/cuda/conv3d_alter_op.py @@ -35,7 +35,7 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/intel_graphics/conv2d_alter_op.py b/python/tvm/topi/intel_graphics/conv2d_alter_op.py index 0b59a849c2c9..199d984af1e4 100644 --- a/python/tvm/topi/intel_graphics/conv2d_alter_op.py +++ b/python/tvm/topi/intel_graphics/conv2d_alter_op.py @@ -35,7 +35,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): cfg = dispatch_ctx.query(target, None) workload = cfg.workload else: - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/mali/conv2d.py b/python/tvm/topi/mali/conv2d.py index f3ef55b9a30c..051914113a5b 100644 --- a/python/tvm/topi/mali/conv2d.py +++ b/python/tvm/topi/mali/conv2d.py @@ -531,7 +531,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py index 8e47dff37ce6..3f2df655a615 100644 --- a/python/tvm/topi/x86/conv2d_alter_op.py +++ b/python/tvm/topi/x86/conv2d_alter_op.py @@ -57,7 +57,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): cfg = dispatch_ctx.query(target, None) workload = cfg.workload else: - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 8db84497f82d..1d64261a50d7 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -35,7 +35,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): M, K = get_const_tuple(data_tensor.shape) N, _ = get_const_tuple(weight_tensor.shape) - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.dense"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ef82ed617508..7005e94c2411 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -33,7 +33,7 @@ #include "../../target/func_registry_generator.h" #include "../../target/source/codegen_source_base.h" -#include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { @@ -295,8 +295,6 @@ class RelayBuildModule : public runtime::ModuleNode { executor_ = executor; CheckAndUpdateHostConsistency(&targets_, &target_host_); BuildRelay(mod, params_, mod_name); - // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. - CompileEngine::Global()->Clear(); } protected: diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc deleted file mode 100644 index 0e7af2278375..000000000000 --- a/src/relay/backend/compile_engine.cc +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.cc - * \brief Internal compilation engine. - */ -#include "compile_engine.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../runtime/meta_data.h" -#include "../transforms/pass_utils.h" -#include "te_compiler_cache.h" -#include "utils.h" - -namespace tvm { -namespace relay { - -TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); - -class CompileEngineImpl : public CompileEngineNode { - public: - // Lower the function. - CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { - return LowerInternal(key, mangle_fn)->cached_func; - } - - CachedFunc Lower(const CCacheKey& key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; - - return Lower(key, mangle_fn); - } - - // For now, build one module per function. - PackedFunc JIT(const CCacheKey& key) final { - auto mangle_fn = [](String name) { return name; }; - CCacheValue value = LowerInternal(key, mangle_fn); - if (value->packed_func != nullptr) return value->packed_func; - auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); - value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); - return value->packed_func; - } - - CachedFunc LowerShapeFunc(const CCacheKey& key) final { - return LowerShapeFuncInternal(key)->cached_func; - } - - Array LowerExternalFunctions() { - Array ret; - std::unordered_map cached_symbol; - std::vector cached_ext_funcs; - for (const auto& it : cache_) { - auto src_func = it.first->source_func; - ICHECK(src_func.defined()); - - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); - ICHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen.value(); - cached_ext_funcs.push_back(it.first); - - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" - << AsText(src_func, false) << "\n" - << "Functions with external codegen must have the " - << tvm::attr::kGlobalSymbol << " attr set."; - - std::string sn = symbol_name.value(); - if (!cached_symbol.count(sn)) { - cached_symbol[sn] = code_gen_name; - } else { - ICHECK_NE(cached_symbol[sn], code_gen_name) - << "Found duplicated symbol: " << sn << " for: " << code_gen_name; - } - - std::string ext_name = "relay.ext." + code_gen_name; - auto pf = tvm::runtime::Registry::Get(ext_name); - ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; - // No need to keep compiler attribute at this point, functions have been - // extracted for specific codegen. - src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); - runtime::Module ext_mod = (*pf)(src_func); - - // todo(@zhiics, @jroesch): Should this be a user visible error? - ICHECK(ext_mod.defined()) << "No external library was generated for " << ext_name - << "even though it was requested" - "by the annotated function " - << PrettyPrint(src_func); - - ret.push_back(ext_mod); - } - } - - // No need to cache external functions as we collected them all to create - // external runtime modules. - for (const auto& it : cached_ext_funcs) { - cache_.erase(it); - } - return ret; - } - - void Clear() final { cache_.clear(); } - - // List all items in the cache. - Array ListItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - // List all items in the shape_func_cache. - Array ListShapeFuncItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : shape_func_cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - /*! - * \brief Get the cache key of the function that is being lowered currently - * \return the cache key - */ - CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } - - private: - // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = cache_.find(key); - if (it != cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - if (!backend::IsCompileEngineCacheDisabled()) { - cache_[key] = value; - } - } - cur_ccache_key_ = key; - - // No need to lower external functions for now. We will invoke the external - // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto ir_module = IRModule(); - const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - auto func_name = std::string(name_node.value()); - auto target = Target("ext_dev"); - auto global_var = GlobalVar(func_name); - global_var->checked_type_ = key->source_func->checked_type(); - ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); - return value; - } - - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(mangle_fn(name), &name_map_); - }); - - // Skip lowering for device copy node. - const Expr body = (key->source_func)->body; - if (const CallNode* call_node = body.as()) { - if (call_node->attrs.as()) { - value->cached_func = cfunc; - return value; - } - } - - // NOTE: array will copy on write. - Array all_args = Array(cfunc->inputs); - for (te::Tensor arg : cfunc->outputs) { - all_args.push_back(arg); - } - // lower the function - std::unordered_map binds; - auto func_name = cfunc->prim_fn_var->name_hint; - cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); - value->cached_func = cfunc; - - return value; - } - - // implement lowered shape func - CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = shape_func_cache_.find(key); - if (it != shape_func_cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - shape_func_cache_[key] = value; - } - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); - - auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(name, &name_map_); - }); - - value->cached_func = cached_func; - return value; - } - - /*! \brief compiler cache lock*/ - std::mutex mutex_; - /*! \brief internal name map to get an unique name */ - std::unordered_map name_map_; - /*! \brief internal compiler cache */ - std::unordered_map cache_; - /*! \brief internal compiler cache for shape funcs */ - std::unordered_map shape_func_cache_; - /*! \brief the cache key of the function that is being lowered currently*/ - CCacheKey cur_ccache_key_; -}; - -/*! \brief The global compile engine */ -CompileEngine& CompileEngine::Global() { - // intentionally allocate raw pointer to avoid - // free during destructuion. - static CompileEngine* inst = new CompileEngine(make_object()); - return *inst; -} - -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool); - -TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") - .set_body_typed([](tvm::Array outputs, OpImplementation impl) { - return LoweredOutput(outputs, impl); - }); - -TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") - .set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { - return CompileEngine::Global(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) { - self->Clear(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") - .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { - return self->Lower(key, mod_name); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") - .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListItems(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListShapeFuncItems(); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->GetCurrentCCacheKey(); - }); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h deleted file mode 100644 index 4afdc6d30485..000000000000 --- a/src/relay/backend/compile_engine.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.h - * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. - * - * This layer represents the older design of the Relay compilation flow and is being deprecated - * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of - * Relay functions. - * - */ -#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ -#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "te_compiler_cache.h" - -namespace tvm { -namespace relay { - -using namespace tvm::relay::tec; - -/*! - * \brief Backend compilation engine for - * low level code generation. - */ -class CompileEngineNode : public Object { - public: - /*! \brief destructor */ - virtual ~CompileEngineNode() {} - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The mangling function for mangling names. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; - - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The module name to mangle the functions. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, const String mangle_fn) = 0; - /*! - * \brief Just in time compile to get a PackedFunc. - * \param key The key to the cached function. - * \return The result. - */ - virtual PackedFunc JIT(const CCacheKey& key) = 0; - /*! - * \brief Lower the shape function. - * \param key The key to the cached function. - * \return The result. - */ - virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; - /*! - * \brief Lower the external function using external codegen tools. - * \return The runtime moduels for each needed external codegen tool. - */ - virtual tvm::Array LowerExternalFunctions() = 0; - - /*! \brief clear the cache. */ - virtual void Clear() = 0; - - // VisitAttrs - void VisitAttrs(AttrVisitor*) {} - - static constexpr const char* _type_key = "relay.CompileEngine"; - TVM_DECLARE_FINAL_OBJECT_INFO(CompileEngineNode, Object); -}; - -/*! \brief cache entry used in compile engine */ -class CompileEngine : public ObjectRef { - public: - CompileEngine() {} - explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} - CompileEngineNode* operator->() { return static_cast(get_mutable()); } - using ContainerType = CompileEngineNode; - /*! \brief The global compile engine. */ - TVM_DLL static CompileEngine& Global(); -}; - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index ef89fd9c9c6c..a596e09907d5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -37,7 +37,7 @@ #include "../op/annotation/annotation.h" #include "../transforms/pass_utils.h" -#include "./te_compiler.h" +#include "te_compiler.h" namespace tvm { namespace relay { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 445602540dbb..a8c27a126032 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -313,6 +313,45 @@ TECompiler::TECompiler() { data_ = object; } +/*! \brief The global TE compiler */ +TECompiler& TECompiler::Global() { + static TECompiler* inst = new TECompiler(make_object()); + return *inst; +} +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() { + return TECompiler::Global(); +}); + +TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") + .set_body_typed([](Function source_func, Target target) { + return CCacheKey(source_func, target); + }); + +TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") + .set_body_typed([](tvm::Array outputs, OpImplementation impl) { + return LoweredOutput(outputs, impl); + }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerClear").set_body_typed([](TECompiler self) { + self->Clear(); +}); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerLower") + .set_body_typed([](TECompiler self, CCacheKey key, const String mod_name) { + return self->Lower(key, mod_name); + }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerJIT") + .set_body_typed([](TECompiler self, CCacheKey key) { return self->JIT(key); }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECompiler self) { + TECompilerImpl* ptr = dynamic_cast(self.operator->()); + ICHECK(ptr != nullptr); + return ptr->ListItems(); +}); + using AnalysisRemapping = std::unordered_map; std::tuple IsDeviceCopy(const Function& func) { diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 248fd40f98eb..e3b7d46457ad 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -127,6 +127,7 @@ class TECompiler : public ObjectRef { explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} TECompilerNode* operator->() { return static_cast(get_mutable()); } using ContainerType = TECompilerNode; + TVM_DLL static TECompiler& Global(); }; /*! @@ -193,7 +194,7 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \returns The pass which lowers primative functions to TIR + * \returns The pass which lowers primitive functions to TIR */ transform::Pass LowerTEPass(TargetMap targets, const String& module_name, std::function process_fn); diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 47ba96b2c77e..7975ef873173 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -62,7 +62,6 @@ struct LoweredOutputNode : public Object { v->Visit("outputs", &outputs); v->Visit("implementation", &implementation); } - static constexpr const char* _type_key = "relay.LoweredOutput"; TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); }; diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 6d59b858927c..febb550d45c0 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -427,15 +427,6 @@ inline bool IsAutoSchedulerEnabled() { .value(); } -/*! - * \brief Return whether the compile engine cache is disabled in the pass context. - */ -inline bool IsCompileEngineCacheDisabled() { - return transform::PassContext::Current() - ->GetConfig("relay.backend.disable_compile_engine_cache", Bool(false)) - .value(); -} - /*! * \brief Get the sequence of Relay optimization passes based on backend type. * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 7a86af8aeffa..c538dac048b3 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -34,7 +34,7 @@ #include #include -#include "../backend/compile_engine.h" +#include "../backend/te_compiler.h" #include "pattern_utils.h" namespace tvm { @@ -126,7 +126,8 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); + tec::PrimFuncFor(GetRef(func), Target::Current(), + [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 3cd5df613f4a..4e24434642d8 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -41,7 +41,7 @@ namespace runtime { struct TypeInfo { /*! \brief The current index. */ uint32_t index{0}; - /*! \brief Index of the parent in the type hierachy */ + /*! \brief Index of the parent in the type hierarchy */ uint32_t parent_index{0}; // NOTE: the indices in [index, index + num_reserved_slots) are // reserved for the child-class of this type. @@ -58,7 +58,7 @@ struct TypeInfo { }; /*! - * \brief Type context that manages the type hierachy information. + * \brief Type context that manages the type hierarchy information. */ class TypeContext { public: diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index f151a85ec5b1..e582874d1de2 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -184,7 +184,7 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti ), "Got {} Arm Compute Library partitions, expected {}".format( partition_count, acl_partitions ) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target=target, params=params) diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py index 46bd049402a9..5a12b0487408 100644 --- a/tests/python/contrib/test_bnns/infrastructure.py +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -142,7 +142,7 @@ def build_module(mod, target, params=None, enable_bnns=True, tvm_ops=0): with tvm.transform.PassContext(opt_level=3): if enable_bnns: mod = partition_for_bnns(mod) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target=target, target_host=target, params=params) diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index 92e8f11a2312..c5ebde4b9c61 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -149,7 +149,7 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1): npu_partitions : int, optional The number of Ethos-N partitions expected. """ - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() with tvm.transform.PassContext( opt_level=3, config={"relay.ext.ethos-n.options": {"variant": get_ethosn_variant()}} ): @@ -262,7 +262,7 @@ def test_error(mod, params, err_msg): except tvm.error.TVMError as e: caught = e.args[0] finally: - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() assert caught is not None assert err_msg in caught, caught diff --git a/tests/python/contrib/test_vitis_ai/infrastructure.py b/tests/python/contrib/test_vitis_ai/infrastructure.py index e87d4f874630..578ac37da25b 100644 --- a/tests/python/contrib/test_vitis_ai/infrastructure.py +++ b/tests/python/contrib/test_vitis_ai/infrastructure.py @@ -99,7 +99,7 @@ def build_module( ), "Got {} Vitis-AI partitions, expected {}".format( partition_count, vitis_ai_partitions ) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target, params=params) diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 746f595a4422..276cad375357 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -33,8 +33,10 @@ import tvm from tvm import relay +from tvm import te from tvm.contrib import utils, graph_executor -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler +from tvm.relay.backend.te_compiler import TECompiler from tvm.relay.backend.utils import mangle_module_name from tvm.micro import export_model_library_format @@ -721,7 +723,6 @@ def compile_and_run( def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" - compile_engine.get().clear() with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target, params=params) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 22583eda4a40..7669d02cd536 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -41,7 +41,7 @@ def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets() tvm.testing.assert_allclose(op_result.numpy(), ref_result, rtol=1e-5) else: tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() @tvm.testing.uses_gpu @@ -251,7 +251,8 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ verify_sparse_to_dense( [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1] ) # floats - verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + # default value not specified + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) @pytest.mark.parametrize( diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index ca792204c835..c6eb7531f635 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -26,7 +26,7 @@ from tvm import relay, runtime from tvm.contrib import utils from tvm.relay import transform -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.contrib.register import get_pattern_table @@ -47,7 +47,7 @@ def check_result( return # Run the reference result - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(ref_mod, target=target, params=params) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) @@ -61,7 +61,7 @@ def check_result( ref_result = out.numpy() def check_vm_result(): - compile_engine.get().clear() + te_compiler.get().clear() with relay.build_config(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -71,7 +71,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.numpy(), ref_result, rtol=tol, atol=tol) def check_graph_executor_result(): - compile_engine.get().clear() + te_compiler.get().clear() with relay.build_config(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index eaddd33678df..754c9d1c4a74 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1422,7 +1422,8 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ verify_sparse_to_dense( [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1] ) # floats - verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + # default value not specified + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # negative test cases # sparse indices should be ints @@ -1757,7 +1758,7 @@ def verify_func(target, dev, func, data, ref_res): tvm.testing.assert_allclose(op_result.numpy(), ref_result, rtol=1e-5) else: tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() def test_adv_index(target, dev, executor_kind): @@ -1970,7 +1971,8 @@ def calc_numpy_unique(data, is_sorted=False): uniq = uniq[order].astype(data.dtype) inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") counts = counts[order].astype("int32") - index = np.sort(index) # In unsorted case, need to sort the index of first occurence + # In unsorted case, need to sort the index of first occurence + index = np.sort(index) return [ uniq.astype(data.dtype), index.astype("int32"), diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 93cd6f791765..5aba6229c5e2 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -22,6 +22,7 @@ import numpy as np import tvm +from tvm.relay.backend import te_compiler import tvm.relay.testing import tvm.relay.op as reg from tvm import relay @@ -29,7 +30,6 @@ from tvm.relay import transform from tvm.relay.testing import byoc from tvm.contrib import utils -from tvm.relay.backend import compile_engine from tvm.relay.expr_functor import ExprMutator from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.contrib.register import get_pattern_table @@ -143,7 +143,7 @@ def update_lib(lib): return lib def check_vm_result(): - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -157,7 +157,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol) def check_graph_executor_result(): - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) lib = update_lib(lib) @@ -508,7 +508,7 @@ def test_extern_dnnl_mobilenet(): ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( i_data, **params ) - compile_engine.get().clear() + te_compiler.get().clear() check_result(mod, {"data": i_data}, (1, 1000), ref_res.numpy(), tol=1e-5, params=params) @@ -950,7 +950,7 @@ def test_exec(mod, params, ref_mod, ref_params, out_shape): ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( i_data, **ref_params ) - compile_engine.get().clear() + te_compiler.get().clear() mod = get_partitoned_mod(mod, params, dnnl_patterns) diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_relay_te_compiler.py similarity index 93% rename from tests/python/relay/test_backend_compile_engine.py rename to tests/python/relay/test_relay_te_compiler.py index 092cae01f568..f8498ae83648 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_relay_te_compiler.py @@ -21,6 +21,7 @@ from tvm import relay from tvm import autotvm from tvm import topi +from tvm.relay.backend import te_compiler from tvm.relay.testing import run_infer_type from tvm.relay.testing.temp_op_attr import TempOpAttr import tvm.testing @@ -98,7 +99,7 @@ def _get_impls(dshape, wshape): weight = relay.var("wshape", shape=wshape) out = relay.nn.conv2d(data, weight, padding=(1, 1)) out = run_infer_type(out) - return relay.backend.compile_engine.get_valid_implementations( + return relay.backend.te_compiler.get_valid_implementations( relay.op.get("nn.conv2d"), out.attrs, [te.placeholder(dshape), te.placeholder(wshape)], @@ -121,7 +122,7 @@ def _select_impl(dshape, wshape, use_autotvm=False): weight = relay.var("wshape", shape=wshape) out = relay.nn.conv2d(data, weight, padding=(1, 1)) out = run_infer_type(out) - return relay.backend.compile_engine.select_implementation( + return relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), out.attrs, [te.placeholder(dshape), te.placeholder(wshape)], @@ -161,8 +162,8 @@ def _select_impl(dshape, wshape, use_autotvm=False): assert impl.name == "conv2d_1" -def test_compile_engine(): - engine = relay.backend.compile_engine.get() +def test_te_compiler(): + tec = relay.backend.te_compiler.get() def get_func(shape): x = relay.var("x", shape=shape) @@ -173,31 +174,30 @@ def get_func(shape): mod = relay.transform.InferType()(mod) return mod["main"] - z1 = engine.lower(get_func((10,)), "llvm") - z2 = engine.lower(get_func((10,)), "llvm") - z3 = engine.lower(get_func(()), "llvm") + z1 = tec.lower(get_func((10,)), "llvm") + z2 = tec.lower(get_func((10,)), "llvm") + z3 = tec.lower(get_func(()), "llvm") assert z1.same_as(z2) assert not z3.same_as(z1) if tvm.testing.device_enabled("cuda"): - z4 = engine.lower(get_func(()), "cuda") + z4 = tec.lower(get_func(()), "cuda") assert not z3.same_as(z4) # Test JIT target for target in ["llvm"]: dev = tvm.device(target) if tvm.testing.device_enabled(target): - f = engine.jit(get_func((10,)), target) + f = tec.jit(get_func((10,)), target) x = tvm.nd.array(np.ones(10).astype("float32"), device=dev) y = tvm.nd.empty((10,), device=dev) f(x, y) tvm.testing.assert_allclose(y.numpy(), x.numpy() * 3) - engine.dump() -# Note: Once compile engine is removed, we should keep this test so that +# Note: Once the te compiler is removed, we should keep this test so that # we make sure that opt_level=0 passes are being called correctly. def test_compile_placeholder_bypass(): - engine = relay.backend.compile_engine.get() + te_compiler = relay.backend.te_compiler.get() x = relay.var("x", shape=(2, 3)) y = relay.var("y", shape=(2, 3)) z = relay.var("z", shape=(2, 3)) @@ -264,7 +264,7 @@ def test_compile_nhwc_pack(): if __name__ == "__main__": test_get_valid_implementations() test_select_implementation() - test_compile_engine() + test_te_compiler() test_compile_placeholder_bypass() test_compile_injective_with_tuple() test_compile_tuple_dup() diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index b5620d748d8a..9b95266d3287 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -63,7 +63,8 @@ def check(m, n, target_bits, target_dtype): # const shape # i32 -> i32 check(2, 2, 32, "int32") - check(2 ** 16, 2 ** 16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow + # i32 + i32 is not promoted to i64 even if overflow + check(2 ** 16, 2 ** 16, 32, "int32") # i64 -> i32 check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32") check(const(2 ** 16, dtype="int64"), const(2 ** 16, dtype="int64"), 32, "int64") @@ -185,7 +186,7 @@ def check(m, n, target_bits, target_dtype): def test_relay_basic(): - engine = relay.backend.compile_engine.get() + engine = relay.backend.te_compiler.get() def check(shapex, shapey, target_bits, target_dtype): x = relay.var("x", shape=shapex) @@ -227,7 +228,7 @@ def check(shapex, shapey, target_bits, target_dtype): def test_relay_take(): - engine = relay.backend.compile_engine.get() + engine = relay.backend.te_compiler.get() def check(shape, index, target_bits, target_dtype): x = relay.var("x", shape=shape) From e830a1f31d50e93ca63dbdef8e497c887862bdb6 Mon Sep 17 00:00:00 2001 From: Raghav Chakravarthy Date: Fri, 22 Oct 2021 17:23:52 -0400 Subject: [PATCH 75/84] [Code Style] Changed code to match the tvm code style conventions. (#9040) * [Code Style] Changed code to match the tvm code style conventions. [Issue] While reviewing the tvm code, I noticed some naming convention issues in the diag_ctx_ and current_func variables. Variable current_func should be current_func_ because it is a class variable Variable diag_ctx_ should be diag_ctx , because it is a public variable [Solution] Changed the variables to match the tvm code style conventions * addressed comments * removed debug logic * fixed plint issue * fixed building issue * fixed whitespace issue * fixed linting error in type_solver.cc --- src/relay/analysis/type_solver.cc | 26 ++++++++++++-------------- src/relay/analysis/type_solver.h | 8 ++------ src/relay/transforms/type_infer.cc | 2 +- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 22e2e9a71040..1421906a3bbb 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -131,12 +131,11 @@ class TypeSolver::Unifier : public TypeFunctor { Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type); if (!resolved.defined()) { - solver_->diag_ctx_.Emit( - Diagnostic::Error(this->span) - << "The Relay type checker is unable to show the following types match.\n" - << "In particular " - << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" - << PrettyPrint(rhs->resolved_type) << "`"); + solver_->Emit(Diagnostic::Error(this->span) + << "The Relay type checker is unable to show the following types match.\n" + << "In particular " + << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" + << PrettyPrint(rhs->resolved_type) << "`"); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -233,11 +232,10 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span) - << "tensor type `" << PrettyPrint(tt1) << "` has " - << tt1->shape.size() << " dimensions, while `" - << PrettyPrint(tt2) << "` has " << tt2->shape.size() - << " dimensions"); + this->solver_->Emit(Diagnostic::Error(this->span) + << "tensor type `" << PrettyPrint(tt1) << "` has " << tt1->shape.size() + << " dimensions, while `" << PrettyPrint(tt2) << "` has " + << tt2->shape.size() << " dimensions"); return Type(nullptr); } @@ -266,7 +264,7 @@ class TypeSolver::Unifier : public TypeFunctor { err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) << " does not match " << std::get<2>(mismatch) << "."; } - this->solver_->diag_ctx_.Emit(err); + this->solver_->Emit(err); return Type(nullptr); } @@ -526,7 +524,7 @@ class TypeSolver::Merger : public TypeFunctor { // constructor TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx) : reporter_(make_object(this)), - current_func(current_func), + current_func_(current_func), diag_ctx_(diag_ctx), module_(diag_ctx->module) { ICHECK(module_.defined()); @@ -618,7 +616,7 @@ bool TypeSolver::Solve() { rnode->resolved = resolved; } catch (const CompileError& err) { - this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what()); + this->Emit(Diagnostic::Error(rnode->span) << err.what()); rnode->resolved = false; } catch (const Error& e) { ICHECK(false) << e.what(); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 56cea60ceeda..3bde1a1e3746 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -94,7 +94,7 @@ class TypeSolver { * \brief Report a diagnostic. * \param diag The diagnostic to report. */ - void EmitDiagnostic(const Diagnostic& diag); + void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); } private: class OccursChecker; @@ -176,13 +176,9 @@ class TypeSolver { /*! \brief Reporter that reports back to self */ TypeReporter reporter_; /*! \brief The global representing the current function. */ - GlobalVar current_func; - - public: + GlobalVar current_func_; /*! \brief The diagnostic context. */ DiagnosticContext diag_ctx_; - - private: /*! \brief The module. */ IRModule module_; diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 5ca6d86b1d52..6d74e48e871e 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -666,7 +666,7 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Type checked_type = solver_->Resolve(it->second.checked_type); if (checked_type.as() != nullptr) { - this->solver_->diag_ctx_.Emit( + this->solver_->Emit( Diagnostic::Error(op->span) << "The type inference pass was unable to infer a type for this expression.\n" << "This usually occurs when an operator call is under constrained in some way," From 4fb6fa5f3817af3415e0b493293522bd8bf52b5b Mon Sep 17 00:00:00 2001 From: Jason Date: Sat, 23 Oct 2021 05:25:16 +0800 Subject: [PATCH 76/84] [Frontend][PaddlePaddle] Add autopad for conv/pool (#9295) * Add autopad for conv/pool * add autopad for conv/pool * fix pylint warning * add some annotations * add som annotations * add som annotations * Refactor autopad in the onnx.py and paddlepaddle.py to relay/frontend/common.py * add comment for conv2d Co-authored-by: heliqi <1101791222@qq.com> --- python/tvm/relay/frontend/common.py | 74 +++++++++ python/tvm/relay/frontend/onnx.py | 79 +--------- python/tvm/relay/frontend/paddlepaddle.py | 142 +++++++++++------- .../frontend/paddlepaddle/test_forward.py | 119 +++++++++++---- 4 files changed, 253 insertions(+), 161 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 825a586918f8..cf579923e301 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -28,6 +28,7 @@ from .. import function as _function from .. import transform as _transform from .. import op as _op +from .. import ty as _ty from .. import analysis # pylint: disable=invalid-name @@ -594,6 +595,16 @@ def try_infer_value(val, on_success=None, on_failure=None): return val, False +def shape_of(x, dtype="int64"): + """Get shape of a tensor.""" + + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + return _op.shape_of(x, dtype) + + def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"): return _expr.var(name_hint, type_annotation, shape, dtype) @@ -837,6 +848,69 @@ def lstm_cell( return outputs_list, hidden_state, cell_state +def autopad( + data, + strides, + kernel_shape, + dilations=(1, 1), + pad_type="constant", + deconv=False, + mode="SAME_UPPER", + pad_value=0.0, +): + """ + Perform autopadding with dynamic input shapes + """ + # get attributes as constants + strides = _op.const(np.array(strides), dtype="int64") + dilated_kernel_shape = _op.const( + np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ), + dtype="int64", + ) + # get input shape + ndim = len(infer_shape(data)) + shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) + + # set up integer constants + zero = _op.const(0, dtype="int64") + one = _op.const(1, dtype="int64") + two = _op.const(2, dtype="int64") + + # Calculate total padding + mod = _op.mod(shape, strides) + + left = _op.maximum(dilated_kernel_shape - strides, zero) + right = _op.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _op.where(_op.equal(mod, zero), left, right) + if deconv: + total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad + + # split total padding into before and after + pad_before = _op.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + # combine + if "LOWER" in mode: + pad = _op.concatenate( + [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 + ) + else: + pad = _op.concatenate( + [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) + + if isinstance(pad_value, (float, int)): + pad_value = _op.const(pad_value) + + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) + + def ensure_scalar_shape(x): """ Assume that `x` is a tensor with one element (regardless of tensor rank). diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5c112c7dfce0..3c88f659f6f0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -38,6 +38,7 @@ from .. import ty as _ty from .. import vision as _vision from .common import ( + autopad, AttrCvt, Renamer, ensure_scalar_shape, @@ -51,6 +52,7 @@ infer_value, lstm_cell, new_var, + shape_of, try_resolve_var_to_const, unbind, ) @@ -315,7 +317,6 @@ def _run_calculation(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], [1] * ndim, - ndim, pad_value=pad_val, mode=attr["auto_pad"], ) @@ -411,69 +412,6 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt(op_name="instance_norm")(inputs, attr, params) -def autopad( - data, - strides, - kernel_shape, - dilations, - ndim, - pad_type="constant", - deconv=False, - mode="SAME_UPPER", - pad_value=0.0, -): - """ - Perform autopadding with dynamic input shapes - """ - # get attributes as constants - strides = _op.const(np.array(strides), dtype="int64") - dilated_kernel_shape = _op.const( - np.array( - [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] - ), - dtype="int64", - ) - # get input shape - shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) - - # set up integer constants - zero = _op.const(0, dtype="int64") - one = _op.const(1, dtype="int64") - two = _op.const(2, dtype="int64") - - # Calculate total padding - mod = _op.mod(shape, strides) - - left = _op.maximum(dilated_kernel_shape - strides, zero) - right = _op.maximum(dilated_kernel_shape - mod, zero) - - total_pad = _op.where(_op.equal(mod, zero), left, right) - if deconv: - total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad - - # split total padding into before and after - pad_before = _op.floor_divide(total_pad, two) - pad_after = total_pad - pad_before - - # combine - if "LOWER" in mode: - pad = _op.concatenate( - [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 - ) - else: - pad = _op.concatenate( - [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 - ) - - # pad N and C with zeros - pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - - if isinstance(pad_value, (float, int)): - pad_value = _op.const(pad_value) - - return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) - - class Conv(OnnxOpConverter): """Operator converter for Conv.""" @@ -501,7 +439,6 @@ def _impl_v1(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": @@ -582,7 +519,6 @@ def _impl_v1(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, deconv=True, mode=attr["auto_pad"], ) @@ -974,7 +910,6 @@ def _impl_v1(cls, inputs, attr, params): attr["strides"], attr["kernel_shape"], [1] * ndim, - ndim, mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": @@ -1410,14 +1345,6 @@ def _impl_v9(cls, inputs, attr, params): return out -def shape_of(x, dtype="int64"): - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(shape, dtype) - return _op.shape_of(x, dtype) - - class Shape(OnnxOpConverter): """Operator converter for Shape.""" @@ -3440,7 +3367,6 @@ def _impl_v10(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, pad_value=x_zero_point.data, mode=attr["auto_pad"], ) @@ -3810,7 +3736,6 @@ def _impl_v10(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, pad_value=data_zp, mode=attr["auto_pad"], ) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index c32449546f77..ef361d6c55e8 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -18,6 +18,7 @@ # pylint: disable=import-outside-toplevel """Paddle: PArallel Distributed Deep LEarning.""" +import warnings import numpy as np import tvm @@ -31,11 +32,13 @@ from .. import ty as _ty from .. import op as _op from .common import ( + autopad, fold_constant, get_relay_op, infer_shape, infer_type, infer_value, + shape_of, try_infer_value, new_var, ) @@ -43,20 +46,6 @@ __all__ = ["from_paddle"] -def _get_pad_size(in_size, dilated_kernel_size, stride_size): - """Calculate the paddings size for Conv/Pool in SAME padding mode.""" - - if stride_size == 1 or in_size % stride_size == 0: - pad = max(dilated_kernel_size - stride_size, 0) - else: - pad = max(dilated_kernel_size - (in_size % stride_size), 0) - - pad_before = pad // 2 - pad_after = pad - pad_before - - return [pad_before, pad_after] - - def _dtype_shape_promotion(inputs): """Promote data type and shape for list of tensors.""" @@ -78,16 +67,6 @@ def _dtype_shape_promotion(inputs): return inputs -def shape_of(x, dtype="int32"): - """Get shape of a tensor.""" - - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(np.array(shape), dtype) - return _op.shape_of(x, dtype) - - def _convert_dtype_value(val): """Converts a Paddle type id to a string.""" @@ -248,24 +227,16 @@ def convert_conv2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - if strides[0] == 1 and strides[1] == 1: - pad_h = _get_pad_size(0, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(0, (k_w - 1) * dilations[1] + 1, strides[1]) - else: - input_shape = shape_of(input_x) - h_w = _op.strided_slice(input_shape, [2], [4]) - try: - in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() - except Exception as e: - msg = "Dynamic shape is not supported in SAME padding algorithm while stride!=1" - raise tvm.error.OpAttributeInvalid(msg) from e - pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + # Handle history issue of PaddlePaddle + # while padding_algorithm == "SAME" + # dilations will be set to [1, 1] + dilations = [1, 1] + input_x = autopad(input_x, strides, [k_h, k_w], dilations) + paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] - if len(paddings) == 4: + elif len(paddings) == 4: paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] else: msg = 'Value {} in attribute "padding" of operator Conv is not "valid."' @@ -559,9 +530,9 @@ def convert_matmul(g, op, block): # This implemention almost keeps same with ONNX # Need to check input shape as batch matmul must be supported. - a_shape = shape_of(inputs[0]) + a_shape = shape_of(inputs[0], dtype="int32") a_rank = infer_shape(a_shape)[0] - b_shape = shape_of(inputs[1]) + b_shape = shape_of(inputs[1], dtype="int32") b_rank = infer_shape(b_shape)[0] # When performing a batch matmul, we need to properly handle N-dim shapes. if a_rank > 2 or b_rank > 2: @@ -648,8 +619,8 @@ def convert_mul(g, op, block): y = g.get_node(op.input("Y")[0]) x_num_col_dims = op.attr("x_num_col_dims") y_num_col_dims = op.attr("y_num_col_dims") - x_shape = shape_of(x) - y_shape = shape_of(y) + x_shape = shape_of(x, dtype="int32") + y_shape = shape_of(y, dtype="int32") x_dim = infer_shape(x_shape)[0] y_dim = infer_shape(y_shape)[0] if x_num_col_dims < 0: @@ -686,6 +657,39 @@ def convert_mul(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_padding(g, op, block): + """Operator converter for padding.""" + + input_x = g.get_node(op.input("X")[0]) + input_padding = op.input("Paddings") + if input_padding: + padding = g.get_node(input_padding[0]) + padding = infer_value(padding, g.get_params()).numpy().tolist() + else: + padding = op.attr("paddings") + padding = op.attr("paddings") + value = op.attr("value") + data_format = op.attr("data_format") + mode = op.attr("mode") + assert mode != "circular", "Don't support mod='circular' for PaddlePaddle's padding" + if mode == "replicate": + mode = "edge" + + pad_len = len(padding) + new_paddings = [0] * (pad_len + 4) + for i in range(0, pad_len, 2): + index = -1 - i + if data_format[:2] != "NC": + index = -3 - i + new_paddings[index] = padding[i + 1] + new_paddings[index - 1] = padding[i] + + new_paddings = [new_paddings[i : i + 2] for i in range(0, len(new_paddings), 2)] + + out = _op.nn.pad(input_x, new_paddings, pad_value=value, pad_mode=mode) + g.add_node(op.output("Out")[0], out) + + def convert_pool2d(g, op, block): """Operator converter for pool2d.""" @@ -696,17 +700,19 @@ def convert_pool2d(g, op, block): paddings = op.attr("paddings") padding_algorithm = op.attr("padding_algorithm") pooling_type = op.attr("pooling_type") + if global_pooling: adaptive = True ksize = [1, 1] input_x = g.get_node(op.input("X")[0]) - in_h, in_w = infer_shape(input_x)[2:] + _, _, in_h, in_w = infer_shape(input_x) op_map = { "avg": "avg_pool2d", "max": "max_pool2d", } + strides = op.attr("strides") if isinstance(strides, int): strides = [strides, strides] @@ -718,22 +724,40 @@ def convert_pool2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, ksize[0], strides[0]) - pad_w = _get_pad_size(in_w, ksize[1], strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + input_x = autopad(input_x, strides, ksize) + paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] - if len(paddings) == 4: + elif len(paddings) == 4: paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] else: msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + # handle with special case + # while kernel size less than input size + # shrink kernel size to input size + if not isinstance(in_h, _op.Expr) and in_h < ksize[0]: + ksize[0] = in_h + if not isinstance(in_w, _op.Expr) and in_w < ksize[1]: + ksize[1] = in_w + if not adaptive: - out = getattr(_op.nn, op_map[pooling_type])( - input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode - ) + if pooling_type == "avg": + exclusive = op.attr("exclusive") + out = _op.nn.avg_pool2d( + input_x, + pool_size=ksize, + strides=strides, + padding=paddings, + ceil_mode=ceil_mode, + count_include_pad=not exclusive, + ) + else: + out = getattr(_op.nn, op_map[pooling_type])( + input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode + ) else: out = getattr(_op.nn, "adaptive_" + op_map[pooling_type])(input_x, output_size=ksize) g.add_node(op.output("Out")[0], out) @@ -796,7 +820,7 @@ def convert_shape(g, op, block): """Operator converter for shape.""" x = g.get_node(op.input("Input")[0]) - out = shape_of(x) + out = shape_of(x, dtype="int32") g.add_node(op.output("Out")[0], out) @@ -854,6 +878,17 @@ def convert_softmax(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_squeeze(g, op, block): + """Operator converter for squeeze2.""" + + x = g.get_node(op.input("X")[0]) + axes = op.attr("axes") + if not axes: + axes = None + x = _op.squeeze(x, axis=axes) + g.add_node(op.output("Out")[0], x) + + def convert_unsqueeze(g, op, block): """Operator converter for unsqueeze.""" @@ -904,6 +939,7 @@ def convert_unsqueeze(g, op, block): "matmul": convert_matmul, "matmul_v2": convert_matmul, "mul": convert_mul, + "pad3d": convert_padding, "pool2d": convert_pool2d, "relu": convert_unary_op, "reshape2": convert_reshape, @@ -911,6 +947,7 @@ def convert_unsqueeze(g, op, block): "shape": convert_shape, "slice": convert_slice, "softmax": convert_softmax, + "squeeze2": convert_squeeze, "tanh": convert_unary_op, "unsqueeze2": convert_unsqueeze, } @@ -1062,7 +1099,6 @@ def from_translated_layer(self, layer, shape_dict): def from_paddle(program_or_layer, shape_dict=None, scope=None): """Convert a PaddlePaddle model into an equivalent Relay Function. - PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, and PaddlePaddle scope stores all the weights of PaddlePaddle model. diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index e3d1fc9daf2b..b274d178c9c2 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -382,31 +382,37 @@ def cusum3(inputs): @tvm.testing.uses_gpu def test_forward_conv(): - conv2d_input_shape = [1, 3, 10, 10] - class Conv2D1(nn.Layer): - def __init__(self): + def __init__(self, stride=1, padding=0, dilation=1, groups=1, padding_mode="zeros"): super(Conv2D1, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, bias_attr=True) + self.conv = nn.Conv2D( + 3, + 6, + 3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + ) self.softmax = nn.Softmax() @paddle.jit.to_static def forward(self, inputs): return self.softmax(self.conv(inputs)) - class Conv2D2(nn.Layer): - def __init__(self): - super(Conv2D2, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False) - self.softmax = nn.Softmax() - - @paddle.jit.to_static - def forward(self, inputs): - return self.softmax(self.conv(inputs)) + input_shapes = [[1, 3, 10, 10], [1, 3, 12, 12]] - conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") - verify_model(Conv2D1(), input_data=conv2d_input_data) - verify_model(Conv2D2(), input_data=conv2d_input_data) + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Conv2D1(), input_data=input_data) + verify_model(Conv2D1(stride=2, padding="VALID", dilation=3), input_data=input_data) + verify_model(Conv2D1(stride=2, padding="SAME", dilation=3), input_data=input_data) + verify_model( + Conv2D1(stride=2, padding=3, dilation=3, padding_mode="replicate"), + input_data=input_data, + ) + verify_model(Conv2D1(stride=2, padding="SAME", dilation=2, groups=3), input_data=input_data) @tvm.testing.uses_gpu @@ -538,6 +544,26 @@ def full2(inputs): verify_model(full2, input_data=[input_data]) +@tvm.testing.uses_gpu +def test_forward_squeeze(): + class Squeeze(nn.Layer): + def __init__(self, axis=None): + super(Squeeze, self).__init__() + self.axis = axis + + @paddle.jit.to_static + def forward(self, inputs): + return paddle.squeeze(inputs, axis=self.axis) + + input_shapes = [[1, 1, 3, 1, 5], [5, 1, 6]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Squeeze(axis=None), input_data=input_data) + verify_model(Squeeze(axis=1), input_data=input_data) + input_data = paddle.rand([1], dtype="float32") + verify_model(Squeeze(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_ones_like(): @paddle.jit.to_static @@ -722,24 +748,55 @@ def forward(self, input1, input2): @tvm.testing.uses_gpu def test_forward_pool2d(): - @paddle.jit.to_static - def pool2d1(inputs): - return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) + class Pool2D1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) - @paddle.jit.to_static - def pool2d2(inputs): - return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) + class Pool2D2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) + + class Pool2D3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.avg_pool2d( + inputs, + kernel_size=3, + stride=1, + padding=[1, 1], + exclusive=False, + divisor_override=2.5, + ) + + input_shapes = [[1, 2, 8, 8], [1, 3, 10, 10]] + for input_shape in input_shapes: + input_data = paddle.uniform(shape=input_shape, dtype="float32", min=-1, max=1) + verify_model(Pool2D1(), input_data=input_data) + verify_model(Pool2D2(), input_data=input_data) + verify_model(Pool2D3(), input_data=input_data) - @paddle.jit.to_static - def pool2d3(inputs): - return nn.functional.max_pool2d( - inputs, kernel_size=2, stride=2, padding=0, return_mask=True - ) - input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) - verify_model(pool2d1, input_data=input_data) - verify_model(pool2d2, input_data=input_data) - # verify_model(pool2d3, input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_pad3d(): + class Pad3D(nn.Layer): + def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCDHW"): + super(Pad3D, self).__init__() + self.pad3d = paddle.nn.Pad3D(padding, mode=mode, value=value, data_format=data_format) + + @paddle.jit.to_static + def forward(self, inputs): + return self.pad3d(inputs) + + input_shapes = [[1, 2, 2, 5, 5], [1, 2, 2, 5, 9]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Pad3D(padding=2), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1]), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], value=0.3), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], mode="reflect"), input_data=input_data) + verify_model(Pad3D(padding=3, mode="replicate"), input_data=input_data) @tvm.testing.uses_gpu From 1526ad1f6125a63410754f7acd9f7a1ae5df1c05 Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Fri, 22 Oct 2021 16:25:58 -0500 Subject: [PATCH 77/84] [UnitTest][Flaky] In test_report_serialization, compare csv. (#9275) * [UnitTest][Flaky] In test_report_serialization, compare csv. `str(report)` calls `ReportNode::AsTable()`, which includes aggregate values. Otherwise negligible differences in the computed value can be rounded differently after the round trip. This was first [noticed in CI](https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/PR-9194/7/pipeline/#step-246-log-1217) for an unrelated PR. Testing locally, this failure mode occurred 2 times out of 3000 trials. Switching to `report.csv()` avoids this issue, as it does not include aggregates. * Switched back to using AsTable(), but with column sums disabled. The .csv column headers are in arbitrary order, and do not test whether the `device_metrics` field has been serialized/deserialized correctly. * Added explicit sorting of columns to Report::AsTable --- include/tvm/runtime/profiling.h | 19 ++++-- python/tvm/runtime/profiling/__init__.py | 29 +++++++++ src/runtime/profiling.cc | 59 ++++++++++--------- .../python/unittest/test_runtime_profiling.py | 11 +++- 4 files changed, 82 insertions(+), 36 deletions(-) diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 7b9a68063f16..366f4f1deed1 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -198,13 +198,20 @@ class ReportNode : public Object { */ String AsCSV() const; /*! \brief Create a human readable table of profiling metrics. - * \param aggregate Whether or not to join multiple calls to the same op into a single line. - * \param sort Whether or not to sort call frames by descending duration. If - * false and if `aggregate` is false, frames will be sorted by order of - * appearance in the program. Order is undefined if `sort` is false and - * `aggregate` is true. + * + * \param aggregate Whether or not to join multiple calls to the + * same op into a single line. + * + * \param sort Whether or not to sort call frames by descending + * duration. If false and if `aggregate` is false, frames will + * be sorted by order of appearance in the program. Order is + * undefined if `sort` is false and `aggregate` is true. + * + * \param compute_col_sums Whether or not to include sum totals for + * the Count, Duation, and Percent columns. + * */ - String AsTable(bool sort = true, bool aggregate = true) const; + String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const; /*! \brief Convert this report to JSON. * * Output JSON will be of this format: diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index b91fe727698b..7d40a81e498a 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -47,6 +47,35 @@ def csv(self): """ return _ffi_api.AsCSV(self) + def table(self, sort=True, aggregate=True, col_sums=True): + """Generate a human-readable table + + Parameters + ---------- + sort : bool + + If aggregate is true, whether to sort call frames by + descending duration. If aggregate is False, whether to + sort frames by order of appearancei n the program. + + aggregate : bool + + Whether to join multiple calls to the same op into a + single line. + + col_sums : bool + + Whether to include the sum of each column. + + Returns + ------- + table : str + + A human-readable table + + """ + return _ffi_api.AsTable(self, sort, aggregate, col_sums) + def json(self): """Convert this profiling report into JSON format. diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index a1d06fc8cab8..90d4ac64238f 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -30,6 +30,7 @@ #include #include +#include #include #include #include @@ -342,7 +343,7 @@ String ReportNode::AsJSON() const { return s.str(); } -String ReportNode::AsTable(bool sort, bool aggregate) const { +String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const { // aggregate calls by op hash (or op name if hash is not set) + argument shapes std::vector> aggregated_calls; if (aggregate) { @@ -414,36 +415,38 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { } // compute columnwise sums - std::unordered_map col_sums; - for (auto call : aggregated_calls) { - for (auto p : call) { - if (p.second.as()) { - int64_t val = p.second.as()->value; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->value; - } - col_sums[p.first] = ObjectRef(make_object(val)); - } else if (p.second.as()) { - double val = p.second.as()->microseconds; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->microseconds; - } - col_sums[p.first] = ObjectRef(make_object(val)); - } else if (p.second.as()) { - double val = p.second.as()->percent; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->percent; + if (compute_col_sums) { + std::unordered_map col_sums; + for (auto call : aggregated_calls) { + for (auto p : call) { + if (p.second.as()) { + int64_t val = p.second.as()->value; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->value; + } + col_sums[p.first] = ObjectRef(make_object(val)); + } else if (p.second.as()) { + double val = p.second.as()->microseconds; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->microseconds; + } + col_sums[p.first] = ObjectRef(make_object(val)); + } else if (p.second.as()) { + double val = p.second.as()->percent; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->percent; + } + col_sums[p.first] = ObjectRef(make_object(val)); } - col_sums[p.first] = ObjectRef(make_object(val)); } } + col_sums["Name"] = String("Sum"); + aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator + aggregated_calls.push_back(col_sums); } - col_sums["Name"] = String("Sum"); - aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator - aggregated_calls.push_back(col_sums); // per-device metrics for (auto p : device_metrics) { @@ -454,7 +457,6 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { // Table formatting std::set unique_headers; - for (auto row : aggregated_calls) { for (auto p : row) { unique_headers.insert(p.first); @@ -666,6 +668,7 @@ TVM_REGISTER_OBJECT_TYPE(ReportNode); TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode); TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode); +TVM_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method(&ReportNode::AsTable); TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { return n->AsCSV(); }); TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { return n->AsJSON(); diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index 3e38a526855a..b67142b42358 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -184,8 +184,15 @@ def test_report_serialization(): report = vm.profile(data, func_name="main") report2 = Report.from_json(report.json()) - # equality on reports compares pointers, so we compare the printed results instead. - assert str(report) == str(report2) + # Equality on reports compares pointers, so we compare the printed + # results instead. + + # Use .table() instead of str(), because str() includes aggregate + # and column summations whose values may be impacted by otherwise + # negligible conversion errors. (2 occurrences / 3000 trials) + assert report.table(aggregate=False, col_sums=False) == report2.table( + aggregate=False, col_sums=False + ) if __name__ == "__main__": From bb5e6533e491f8e83371f6dccc4176bbdb7f1e30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Sat, 23 Oct 2021 02:17:46 +0200 Subject: [PATCH 78/84] [Tutorial] Fix formatting, grammar, dead link (#9281) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tutorial: preprocess.py: Fix leading whitespace This fixes the indentation of metadata in `preprocess.py` in the TVMC tutorial, removing the leading whitespaces in the HTML rendering[^1]. [^1] https://tvm.apache.org/docs/tutorial/tvmc_command_line_driver.html#preprocess-py * tutorial: Add missing code block escapes * tutorial: Grammar fixup * README.md: Fix link to introduction Co-authored-by: Martin Kröning --- README.md | 2 +- gallery/tutorial/autotvm_relay_x86.py | 5 ++++- gallery/tutorial/tensor_expr_get_started.py | 8 ++++---- gallery/tutorial/tvmc_command_line_driver.py | 8 ++++---- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 09ceb7ab1d07..d96038d17804 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ TVM is licensed under the [Apache-2.0](LICENSE) license. Getting Started --------------- Check out the [TVM Documentation](https://tvm.apache.org/docs/) site for installation instructions, tutorials, examples, and more. -The [Getting Started with TVM](https://tvm.apache.org/docs/tutorials/get_started/introduction.html) tutorial is a great +The [Getting Started with TVM](https://tvm.apache.org/docs/tutorial/introduction.html) tutorial is a great place to start. Contribute to TVM diff --git a/gallery/tutorial/autotvm_relay_x86.py b/gallery/tutorial/autotvm_relay_x86.py index 8b9c45c2a859..67b832cc226d 100644 --- a/gallery/tutorial/autotvm_relay_x86.py +++ b/gallery/tutorial/autotvm_relay_x86.py @@ -106,7 +106,7 @@ # TVMC has adopted NumPy's ``.npz`` format for both input and output data. # # As input for this tutorial, we will use the image of a cat, but you can feel -# free to substitute image for any of your choosing. +# free to substitute this image for any of your choosing. # # .. image:: https://s3.amazonaws.com/model-server/inputs/kitten.jpg # :height: 224px @@ -278,6 +278,7 @@ from tvm.autotvm.tuner import XGBTuner from tvm import autotvm +################################################################################ # Set up some basic parameters for the runner. The runner takes compiled code # that is generated with a specific set of parameters and measures the # performance of it. ``number`` specifies the number of different @@ -303,6 +304,7 @@ enable_cpu_cache_flush=True, ) +################################################################################ # Create a simple structure for holding tuning options. We use an XGBoost # algorithim for guiding the search. For a production job, you will want to set # the number of trials to be larger than the value of 10 used here. For CPU we @@ -426,6 +428,7 @@ for rank in ranks[0:5]: print("class='%s' with probability=%f" % (labels[rank], scores[rank])) +################################################################################ # Verifying that the predictions are the same: # # .. code-block:: bash diff --git a/gallery/tutorial/tensor_expr_get_started.py b/gallery/tutorial/tensor_expr_get_started.py index fda332cb63ba..e4d947d1c488 100644 --- a/gallery/tutorial/tensor_expr_get_started.py +++ b/gallery/tutorial/tensor_expr_get_started.py @@ -133,7 +133,7 @@ ################################################################################ # Let's run the function, and compare the output to the same computation in -# numpy. The compiled TVM function is exposes a concise C API that can be invoked +# numpy. The compiled TVM function exposes a concise C API that can be invoked # from any language. We begin by creating a device, which is a device (CPU in this # example) that TVM can compile the schedule to. In this case the device is an # LLVM CPU target. We can then initialize the tensors in our device and @@ -258,8 +258,8 @@ def evaluate_addition(func, target, optimization, log): print(tvm.lower(s, [A, B, C], simple_mode=True)) ################################################################################ -# Comparing the Diferent Schedules -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Comparing the Different Schedules +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We can now compare the different schedules baseline = log[0][1] @@ -347,7 +347,7 @@ def evaluate_addition(func, target, optimization, log): fadd = tvm.build(s, [A, B, C], target=tgt_gpu, name="myadd") ################################################################################ - # The compiled TVM function is exposes a concise C API that can be invoked from + # The compiled TVM function exposes a concise C API that can be invoked from # any language. # # We provide a minimal array API in python to aid quick testing and prototyping. diff --git a/gallery/tutorial/tvmc_command_line_driver.py b/gallery/tutorial/tvmc_command_line_driver.py index 7a0b97895e4f..facb978cea67 100644 --- a/gallery/tutorial/tvmc_command_line_driver.py +++ b/gallery/tutorial/tvmc_command_line_driver.py @@ -174,10 +174,10 @@ # data types. For this reason, most models require some pre and # post-processing, to ensure the input is valid and to interpret the output. # TVMC has adopted NumPy's ``.npz`` format for both input and output data. This -# is a well-supported NumPy format to serialize multiple arrays into a file +# is a well-supported NumPy format to serialize multiple arrays into a file. # # As input for this tutorial, we will use the image of a cat, but you can feel -# free to substitute image for any of your choosing. +# free to substitute this image for any of your choosing. # # .. image:: https://s3.amazonaws.com/model-server/inputs/kitten.jpg # :height: 224px @@ -197,8 +197,8 @@ # requirement for the script. # # .. code-block:: python -# :caption: preprocess.py -# :name: preprocess.py +# :caption: preprocess.py +# :name: preprocess.py # # #!python ./preprocess.py # from tvm.contrib.download import download_testdata From 6219d19a682e5823fe93e152c9cee94ded021dca Mon Sep 17 00:00:00 2001 From: shengxinhu <69130386+shengxinhu@users.noreply.github.com> Date: Sat, 23 Oct 2021 08:21:14 +0800 Subject: [PATCH 79/84] [Caffe Frontend] Add support for Embed layer (#9257) * [Caffe Frontend] Add support for Embed layer * [Caffe Frontend] Add support for Embed layer --- python/tvm/relay/frontend/caffe.py | 41 ++++++++++ tests/python/frontend/caffe/test_forward.py | 88 +++++++++++++++++++++ 2 files changed, 129 insertions(+) diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index b8273b0324c0..be76feef7297 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -50,6 +50,7 @@ def __init__(self, init_layer_dict, predict_layer, exp_tab): "Deconvolution": self.convert_deconv, "Dropout": self.convert_dropout, "Eltwise": self.convert_eltwise, + "Embed": self.convert_embed, "Flatten": self.convert_flatten, "InnerProduct": self.convert_innerproduct, "Input": None, @@ -593,6 +594,46 @@ def convert_crop(self, op): out = _op.slice_like(in_expr_a_stride, in_expr_b, axes=to_crop_axis) return out + def convert_embed(self, op): + """Convert Embed layer""" + inputs = op.bottom + embed_param = op.embed_param + num_output = embed_param.num_output + input_dim = embed_param.input_dim + bias_term = embed_param.bias_term + weight_bias_blobs = self.init_layer_dict[op.name].blobs + weight, bias = None, None + if bias_term: + weight = weight_bias_blobs[0] + bias = weight_bias_blobs[1] + assert weight and bias + else: + weight = weight_bias_blobs[0] + assert weight + weight_value = np.asarray(weight.data, np.float32) + weight_value = np.reshape(weight_value, [input_dim, num_output]) + weight_expr = self.exp_tab.new_const(weight_value, dtype="float32") + in_expr = self.exp_tab.get_expr(inputs[0]) + input_shape = _infer_shape(in_expr) + input_count = 1 + for dim in input_shape: + input_count *= dim + + index = _op.cast(in_expr, "int32") + out = _op.take(weight_expr, index, axis=0) + + if bias_term: + bias_value = np.asarray(bias.data, np.float32) + bias_expr = self.exp_tab.new_const(bias_value, dtype="float32") + out = _op.reshape(out, [input_count, num_output]) + out = _op.add(out, bias_expr) + + out_shape = list(input_shape) + out_shape.append(num_output) + out = _op.reshape(out, out_shape) + + return out + def check_unsupported_ops(self): """Check unsupported Caffe ops in our converter.""" unsupported_ops_set = set() diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py index f4c0cd102340..233977d66066 100644 --- a/tests/python/frontend/caffe/test_forward.py +++ b/tests/python/frontend/caffe/test_forward.py @@ -763,6 +763,94 @@ def test_forward_TanH(): _test_tanh(np.random.rand(10).astype(np.float32)) +####################################################################### +# Embed +# ----------- + + +def _test_embed(data, **kwargs): + """One iteration of Embed""" + _test_op(data, L.Embed, "Embed", **kwargs) + + +def test_forward_Embed(): + k = 20 + data = [i for i in range(k)] + np.random.shuffle(data) + # dimension is 1 + data = np.asarray(data) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 2 + data = np.reshape(data, [4, 5]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 3 + data = np.reshape(data, [2, 2, 5]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 4 + data = np.reshape(data, [2, 2, 5, 1]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + + ####################################################################### # Mobilenetv2 # ----------- From e9a66a1a1648b7e06257cc405a3f842a78c23bd9 Mon Sep 17 00:00:00 2001 From: Yuanjing Shi Date: Fri, 22 Oct 2021 23:58:06 -0700 Subject: [PATCH 80/84] [TIR] Add structural error printing for TensorIR (#9306) * add structural error printing * remove old code * address comments * address comments * add test * fix test case * fix nested loop * rm print * change simple loop cond * address comments * fix test * address comments * remove msg * add override * address comments * address comments --- src/printer/text_printer.h | 3 + src/printer/tvmscript_printer.cc | 102 +++++++++++++++--- src/tir/schedule/error.cc | 44 ++++---- .../unittest/test_tvmscript_error_report.py | 73 +++++++++++++ 4 files changed, 191 insertions(+), 31 deletions(-) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index a2178167b2e3..316d59631782 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -379,6 +379,9 @@ class TIRTextPrinter : public StmtFunctor, String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate); + } // namespace tir } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 13e4cfcd30ba..8ac745f675d9 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -91,7 +91,7 @@ class TVMScriptPrinter : public StmtFunctor, */ TVM_DLL Doc Print(const ObjectRef& node); - private: + protected: /*! \brief The tir prefix */ String tir_prefix_; /*! \brief whether show meta data */ @@ -208,6 +208,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintBlockVars(const BlockRealizeNode* op); Doc PrintBlockAttr(const BlockRealizeNode* op); Doc PrintBlockBody(const BlockNode* op); + virtual Doc PrintBlockName(const BlockNode* block_op); Doc PrintBufferRegion(const BufferRegionNode* op); Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); Doc PrintAnnotations(const Map& annotations); @@ -217,15 +218,24 @@ class TVMScriptPrinter : public StmtFunctor, Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); void TryDeallocVar(const Var& var); + bool ContainsOptionalInfo(const Stmt& stmt); /*! Helper functions for loop printing. */ /*! * \brief Print a single for loop * \param loop The for loop to be printed */ - Doc PrintLoop(const For& loop); + virtual Doc PrintLoop(const For& loop); /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ Doc PrintLoopStack(); + /*! + * \brief Print all simple loops in stack into one line using tir_prefix_.grid(). + * \param for_op the for node to be checked + */ + bool IsSimpleLoop(const ForNode* for_op) { + return for_op->kind == ForKind::kSerial && for_op->annotations.empty() && + is_zero(for_op->min) && !ContainsOptionalInfo(GetRef(for_op)); + } /*! * \brief Print additional info about expr in comment. @@ -234,11 +244,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintOptionalInfo(const Stmt& stmt) { Doc doc; // default annotations - if (annotate_ != nullptr) { + if (ContainsOptionalInfo(stmt)) { std::string annotated_stmt = annotate_(stmt); - if (!annotated_stmt.empty()) { - doc << "# " << annotated_stmt << Doc::NewLine(); - } + doc << "# " << annotated_stmt << Doc::NewLine(); } return doc; } @@ -391,6 +399,16 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { return val; } +/*! + * \brief Check if any optional information exists in annotate_ for + * a given Stmt. + * \param stmt The statement. + */ +bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) { + if (annotate_ == nullptr) return false; + return !annotate_(stmt).empty(); +} + /*! * \brief Try to dealloc vars out of space and leave the index to coming vars. * \note It is not a necessary step. @@ -835,14 +853,14 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { var_not_in_headers_.insert(op->loop_var.get()); loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); - bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); + bool simple_loop = IsSimpleLoop(op); if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out - if (simple_loop && body != nullptr) { - Doc result = Print(GetRef(body)); + if (simple_loop && body != nullptr && IsSimpleLoop(body)) { + doc << Print(GetRef(body)); TryDeallocVar(op->loop_var); loop_var_map_.erase(op->loop_var.get()); - return result; + return doc; } // It is a loop that can not be compressed bool print_above = !simple_loop_stack_.empty(); @@ -916,6 +934,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +/*! Helper functions for block printing. */ Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { Doc doc; doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis."; @@ -1049,15 +1068,25 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { return body; } -Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { - const auto* block_op = op->block.as(); - // print block name and block vars +/*! + * \brief Print the name of a block + * \param block_op The block node to be printed + */ +Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { Doc doc; doc << "with " << tir_prefix_ << ".block("; if (!block_op->name_hint.empty()) { doc << Doc::StrLiteral(block_op->name_hint); } doc << "):"; + return doc; +} + +Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { + const auto* block_op = op->block.as(); + Doc doc = PrintOptionalInfo(GetRef(block_op)); + // print block name and block vars + doc << PrintBlockName(block_op); Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); @@ -1343,6 +1372,45 @@ Doc TVMScriptPrinter::PrintLoopStack() { return res; } +/*! + * \brief The printer for TVMScript with diagnostic + * \details The printer obtain the precedence of the top-level operation when printing each + * subexpression to decide whether or not parentheses is needed. + */ +class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { + public: + explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) + : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} + + protected: + Doc PrintBlockName(const BlockNode* block_op) override; + Doc PrintUnderline(const Stmt& stmt, int length); + Doc PrintLoop(const For& loop) override; +}; + +Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { + Doc doc = TVMScriptPrinter::PrintBlockName(block_op); + doc << PrintUnderline(GetRef(block_op), doc.str().size()); + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) { + Doc doc; + // annotation + if (ContainsOptionalInfo(stmt)) { + String underline = std::string(length, '^'); + doc << Doc::NewLine() << underline; + } + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { + Doc res = TVMScriptPrinter::PrintLoop(loop); + res << PrintUnderline(loop, res.str().size()); + return res; +} + String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { ICHECK(mod->IsInstance() || mod->IsInstance()); return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n"; @@ -1350,5 +1418,13 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_met TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) { + ICHECK(mod->IsInstance() || mod->IsInstance()); + return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n"; +} + +TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index d8dcf57b91e4..eb72773ffedb 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -24,29 +24,37 @@ namespace tir { String ScheduleError::RenderReport(const String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; - os << "ScheduleError: An error occurred in the schedule primitive '" << primitive - << "'.\n\nThe IR is:\n" - << AsTVMScript(mod); + + // get locations of interest Array locs = LocationsOfInterest(); + std::unordered_map loc_obj_to_name; int n_locs = locs.size(); - std::vector roi_names; - roi_names.reserve(n_locs); - if (n_locs > 0) { - os << "Regions of interest:\n"; - for (const ObjectRef& obj : locs) { - String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); - os << name << "\n" << obj; - roi_names.emplace_back(std::move(name)); - } - os << "\n"; - } std::string msg = DetailRenderTemplate(); - for (int i = 0; i < n_locs; ++i) { - std::string src = "{" + std::to_string(i) + "}"; - for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { - msg.replace(pos, src.length(), roi_names[i]); + if (n_locs > 0) { + for (int i = 0; i < n_locs; ++i) { + std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(i); + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), name); + } + loc_obj_to_name.emplace(locs[i], std::move(name)); } } + + // print IR module + runtime::TypedPackedFunc annotate = + runtime::TypedPackedFunc( + [&loc_obj_to_name](const Stmt& expr) -> std::string { + auto it = loc_obj_to_name.find(Downcast(expr)); + if (it == loc_obj_to_name.end()) return ""; + return it->second; + }); + + os << "ScheduleError: An error occurred in the schedule primitive '" << primitive + << "'.\n\nThe IR with diagnostic is:\n" + << AsTVMScriptWithDiagnostic(mod, "tir", false, annotate); + + // print error message os << "Error message: " << msg; return os.str(); } diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 80c37229f519..3098c86a7c2e 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -18,6 +18,7 @@ import pytest import sys import tvm +from tvm import tir from tvm.script import tir as T from tvm.ir.diagnostics import override_renderer import inspect @@ -511,5 +512,77 @@ def render(e): # TODO(Siyuan): block iter errors. + +@T.prim_func +def elementwise_not_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 8): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@T.prim_func +def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.alloc_buffer((128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in T.serial(0, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +def test_reorder_fail_block(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(l, i) + expected_sub_error_message = ( + " # tir.Block#0\n" + ' with tir.block("B"):\n' + " ^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop_inner(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(k, i) + expected_sub_error_message = ( + " for i in tir.serial(0, 128):\n" + " # tir.For#0\n" + " for j in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_fuse_fail_nested_loop_outer(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.fuse(k, i) + expected_sub_error_message = ( + " # tir.For#1\n" + " for i in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + " for j in tir.serial(0, 128):\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 4c590a2c855217f4e60cfd18f8eb459e5ba467b8 Mon Sep 17 00:00:00 2001 From: Philipp van Kempen Date: Sun, 24 Oct 2021 09:35:04 +0200 Subject: [PATCH 81/84] Fix inconsistencies in graph_executor function names handling (#9255) * Clean up redundant code in graph_executor.cc How did these lines ended up here? * Fix inconsistencies in graph_executor function names handling Updates value of `TVM_CRT_MAX_STRLEN_FUNCTION_NAME` from `80` to `120` Replace all occurences of `[120]` with `[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]` to maintain consistency and make the array lengths user-configurable. Introduces `TVM_CRT_MAX_STRLEN_PARAM_NAME` used for parameter names only Adds comments to `kMaxFuncNameLength` variabe in src/relay/backend/te_compiler_cache.cc making sure that the values are kept "in sync". (sort of) See #8953 for more context. The actual bug reported there however can only be fixed by increasing the TVM_CRT_MAX_STRLEN_FUNCTION_NAME to a value larger than the maximum possible truncated function name length (including prefixes and suffices) Example: 6 ['tvmgen' prefix length] + 7 ['default' model name length] + 5 ['fused' fused function name prefix length] + 80 [truncated function name length] + 19 [length of appended hash] + 4 [Number of '_' between components] = 121 --- apps/bundle_deploy/crt_config/crt_config.h | 4 +++- .../template_project/crt_config/crt_config.h | 4 +++- .../template_project/crt_config/crt_config.h | 5 +++- include/tvm/runtime/crt/graph_executor.h | 2 +- src/relay/backend/te_compiler_cache.cc | 4 ++++ src/runtime/crt/crt_config-template.h | 5 +++- .../crt/graph_executor/graph_executor.c | 23 +++++++------------ .../internal/graph_executor/graph_executor.h | 2 +- src/runtime/micro/crt_config.h | 4 +++- 9 files changed, 31 insertions(+), 22 deletions(-) diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h index b89bedbc6d45..3adcb2dc8d42 100644 --- a/apps/bundle_deploy/crt_config/crt_config.h +++ b/apps/bundle_deploy/crt_config/crt_config.h @@ -37,7 +37,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/microtvm/arduino/template_project/crt_config/crt_config.h b/apps/microtvm/arduino/template_project/crt_config/crt_config.h index cf73103aff8b..b3126cfac920 100644 --- a/apps/microtvm/arduino/template_project/crt_config/crt_config.h +++ b/apps/microtvm/arduino/template_project/crt_config/crt_config.h @@ -36,7 +36,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h index 39fe27ef3d05..c3beaed522f2 100644 --- a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h +++ b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h @@ -48,7 +48,10 @@ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 + +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 diff --git a/include/tvm/runtime/crt/graph_executor.h b/include/tvm/runtime/crt/graph_executor.h index eb68ff56d230..1353d8e06e6b 100644 --- a/include/tvm/runtime/crt/graph_executor.h +++ b/include/tvm/runtime/crt/graph_executor.h @@ -36,7 +36,7 @@ struct TVMModule; /*! \brief operator attributes about tvm op */ typedef struct TVMOpParam { - char func_name[120]; + char func_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; uint32_t num_inputs; uint32_t num_outputs; uint32_t flatten_data; diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index ec87cfc98931..be5b172e6a7c 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -134,6 +134,8 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator auto outputs = this->VisitExpr(prim_func->body); auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; truncated_name << candidate_name.substr(0, kMaxFuncNameLength); @@ -394,6 +396,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> // Generate a name. auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; truncated_name << candidate_name.substr(0, kMaxFuncNameLength); diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h index aa718a303744..90897a9542b6 100644 --- a/src/runtime/crt/crt_config-template.h +++ b/src/runtime/crt/crt_config-template.h @@ -49,7 +49,10 @@ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 + +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index 34e81c7d33b1..3fea408d9760 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -77,7 +77,7 @@ int NodeEntry_Load(TVMGraphExecutorNodeEntry* entry, JSONReader* reader) { void TVMGraphExecutorNode_LoadAttrs(TVMGraphExecutorNode* node, JSONReader* reader, TVMOpParam* param) { int bitmask = 0; - char key[20], value[120]; + char key[20], value[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; memset(param, 0, sizeof(TVMOpParam)); memset(key, 0, sizeof(key)); memset(value, 0, sizeof(value)); @@ -796,13 +796,13 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl char* names = NULL; DLDevice dev = {kDLCPU, 0}; tvm_crt_error_t err = TVMPlatformMemoryAllocate( - TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count, dev, (void**)&names); + TVM_CRT_MAX_STRLEN_PARAM_NAME * executor->nodes_count, dev, (void**)&names); if (err != kTvmErrorNoError) { fprintf(stderr, "memory allocate error: %08x", err); status = -1; return status; } - memset(names, 0, TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count); + memset(names, 0, TVM_CRT_MAX_STRLEN_PARAM_NAME * executor->nodes_count); uint64_t names_count; int idx; memcpy(&names_count, bptr, sizeof(names_count)); @@ -811,11 +811,11 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl uint64_t name_length; memcpy(&name_length, bptr, sizeof(name_length)); bptr += sizeof(name_length); - if (name_length >= TVM_CRT_MAX_STRLEN_FUNCTION_NAME) { + if (name_length >= TVM_CRT_MAX_STRLEN_PARAM_NAME) { fprintf(stderr, "Error: function name longer than expected.\n"); status = -1; } - memcpy(names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, bptr, name_length); + memcpy(names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx, bptr, name_length); bptr += name_length; } @@ -831,9 +831,9 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl for (idx = 0; idx < size; idx++) { int32_t in_idx = - TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); + TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx); CHECK_GT(in_idx, 0, "Found param for non-existent input: %s\n", - names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); + names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx); uint32_t eid = TVMGraphExecutor_GetEntryId(executor, executor->input_nodes[in_idx], 0); if (!(eid < executor->data_entry_count)) { fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", eid, @@ -859,7 +859,7 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl #if TVM_CRT_DEBUG TVMNDArray* entry = &(executor->data_entry[eid]); printf("loading: param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n", - names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, + names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, ((float*)entry->dl_tensor.data)[0]); // NOLINT(*) #endif // TVM_CRT_DEBUG } @@ -1181,13 +1181,6 @@ int TVMGraphExecutor_Init(TVMGraphExecutor* executor, const char* graph_json, return status; } status = TVMGraphExecutor_SetupOpExecs(executor); - if (status != 0) { - if (status != 0) { - return status; - } - - return status; - } return status; } diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h index c67c43357363..d4429308b650 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h @@ -60,7 +60,7 @@ typedef struct TVMGraphExecutorNode { // operator type in string char op_type[16]; // name of the op - char name[120]; + char name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; // parameters TVMOpParam param; // inputs diff --git a/src/runtime/micro/crt_config.h b/src/runtime/micro/crt_config.h index c3e8fea1ba08..602060de1b4a 100644 --- a/src/runtime/micro/crt_config.h +++ b/src/runtime/micro/crt_config.h @@ -37,7 +37,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 From 5e62db54ad74148e6ffe36af2839cdb855d81940 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 25 Oct 2021 05:24:40 +0800 Subject: [PATCH 82/84] [TVMScript] Parser for Lambdas, Parser/Printer for `CommReducer` (#9358) * CommReducer Parser/Printer * update argmax unit test * update doc * lint fix * add unit tests with multiple reducers --- include/tvm/tir/var.h | 6 ++ python/tvm/script/parser.py | 25 ++++++++ python/tvm/script/tir/intrin.py | 18 ++++++ python/tvm/tir/__init__.py | 2 +- python/tvm/tir/expr.py | 2 +- src/printer/tvmscript_printer.cc | 48 +++++++++++---- src/tir/ir/expr.cc | 41 +++++++++++++ .../unittest/test_tvmscript_roundtrip.py | 60 +++++++++++++++++++ 8 files changed, 188 insertions(+), 14 deletions(-) diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 65c5c12a701b..40a0d1ab2f74 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -109,6 +109,12 @@ class Var : public PrimExpr { * \return the new Var copy */ TVM_DLL Var copy_with_suffix(const String& suffix) const; + /*! + * \brief Make a new copy of the variable with specified dtype + * \param dtype The specified dtype + * \return The new variable + */ + TVM_DLL Var copy_with_dtype(DataType dtype) const; /*! * \brief Get pointer to the internal value. diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 8610d91e9f07..080aa0476bec 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -490,6 +490,31 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: self.context.exit_scope() return func + def transform_Lambda(self, node): + """Lambda visitor + + Return an array of input parameters and the transformed lambda body. + """ + + self.context.enter_scope(nodes=[node.body]) + + # add parameters of the lambda + arg_vars = [] + for arg in node.params: + arg_var = tvm.te.var(arg.name) + arg_vars.append(arg_var) + self.context.update_symbol(arg.name, arg_var, node) + + # the body of a lambda must be an expr + if not isinstance(node.body, ast.Expr): + self.report_error("The body of a lambda must be an expression", node.span) + + # transform the body of the lambda + body = self.transform(node.body) + + self.context.exit_scope() + return arg_vars, body + def transform_Assign(self, node): """Assign visitor AST abstract grammar: diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index 4d7fe80b28b1..2e800355bef6 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -16,6 +16,7 @@ # under the License. """TVM Script Parser Intrinsic Classes""" # pylint: disable=redefined-builtin, relative-beyond-top-level +import builtins from typing import List, Any import tvm.tir @@ -211,3 +212,20 @@ def store(var, index, value, predicate=True, span=None): return tvm.tir.Store(var, value, index, predicate, span) super().__init__(store, stmt=True) + + +@register +def comm_reducer(lambda_io, identities, span): + """Create a CommReducer from lambda inputs/outputs and the identities""" + lambda_input = lambda_io[0] + lambda_output = lambda_io[1] + + num_args = len(lambda_input) + num_arg_per_group = num_args // 2 + x = [lambda_input[i] for i in builtins.range(0, num_arg_per_group)] + y = [lambda_input[i] for i in builtins.range(num_arg_per_group, num_args)] + + if not isinstance(lambda_output, tuple): + lambda_output = (lambda_output,) + + return tvm.tir.CommReducer(x, y, lambda_output, identities, span) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 44006239acfd..428403a98f16 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -25,7 +25,7 @@ from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle -from .expr import Call, CallEffectKind, Let, IterVar, Any +from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 2bfa0aacb184..27cf5351a077 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -442,7 +442,7 @@ def __init__(self, dom, var, iter_type, thread_tag="", span=None): @tvm._ffi.register_object("tir.CommReducer") class CommReducer(Object): - """Communicative reduce operator + """Commutative reduce operator Parameters ---------- diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 8ac745f675d9..d82ad74fd5c3 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -119,8 +119,6 @@ class TVMScriptPrinter : public StmtFunctor, std::unordered_map memo_buf_; /*! \brief Map from Buffer to Declaration Doc */ std::unordered_map memo_buf_decl_; - /*! \brief Map from CommReducer to Doc */ - std::unordered_map memo_reducer_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief number of children of current node's parent */ @@ -211,6 +209,7 @@ class TVMScriptPrinter : public StmtFunctor, virtual Doc PrintBlockName(const BlockNode* block_op); Doc PrintBufferRegion(const BufferRegionNode* op); Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); + Doc PrintCommReducer(const CommReducerNode* op); Doc PrintAnnotations(const Map& annotations); static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } @@ -445,6 +444,39 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { return doc; } +Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) { + Doc doc; + int n_var = static_cast(op->rhs.size()); + + doc << tir_prefix_ << ".comm_reducer(lambda "; + for (const Var& v_lhs : op->lhs) { + doc << Print(v_lhs) << ", "; + } + for (int i = 0; i < n_var; ++i) { + doc << Print(op->rhs[i]) << (i == n_var - 1 ? ": " : ", "); + } + if (n_var == 1) { + doc << Print(op->result[0]) << ", "; + } else { + doc << "("; + for (int i = 0; i < n_var; ++i) { + doc << Print(op->result[i]); + if (i != n_var - 1) { + doc << ", "; + } + } + doc << "), "; + } + doc << Print(op->identity_element) << ")"; + + // Remove the vars in `lhs` and `rhs`, because they are the parameters of the printed lambda. + for (int i = 0; i < n_var; ++i) { + memo_var_.erase(op->lhs[i]); + memo_var_.erase(op->rhs[i]); + } + return doc; +} + Doc TVMScriptPrinter::Print(const ObjectRef& node) { if (!node.defined()) return Doc::Text("None"); if (node->IsInstance()) { @@ -472,6 +504,8 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) { return PrintBufferRegion(node.as()); } else if (node->IsInstance()) { return PrintMatchBufferRegion(node.as()); + } else if (node->IsInstance()) { + return PrintCommReducer(node.as()); } else { LOG(FATAL) << "Do not know how to print " << node->GetTypeKey(); return Doc(); @@ -1153,7 +1187,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { memo_var_.clear(); memo_buf_.clear(); memo_buf_decl_.clear(); - memo_reducer_.clear(); var_not_in_headers_.clear(); buf_not_in_headers_.clear(); // print signature @@ -1178,15 +1211,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { body << Print((*it).first) << ", " << memo_buf_decl_[(*it).second]; body << ")" << Doc::NewLine(); } - // print comm_reducer - for (const auto& it : memo_reducer_) { - body << it.second << " = .comm_reducer("; - var_not_in_headers_.insert(it.first->lhs[0].get()); - var_not_in_headers_.insert(it.first->rhs[0].get()); - body << "lambda " << Print(it.first->lhs[0]) << ", " << Print(it.first->rhs[0]) << ": " - << Print(it.first->result[0]) << ", " << Print(it.first->identity_element[0]); - body << ")" << Doc::NewLine(); - } // print body body << "# body" << Doc::NewLine(); if (op->body->IsInstance() && diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index afc5c36ebb92..1d7c959d990d 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -90,6 +90,18 @@ Var Var::copy_with_suffix(const String& suffix) const { return Var(new_ptr); } +Var Var::copy_with_dtype(DataType dtype) const { + const VarNode* node = get(); + ObjectPtr new_ptr; + if (auto* ptr = this->as()) { + new_ptr = make_object(*ptr); + } else { + new_ptr = make_object(*node); + } + new_ptr->dtype = std::move(dtype); + return Var(new_ptr); +} + TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type, Span span) { if (type.IsObjectRef()) { @@ -904,6 +916,35 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // CommReducer CommReducer::CommReducer(Array lhs, Array rhs, Array result, Array identity_element, Span span) { + size_t n_group = result.size(); + CHECK_EQ(lhs.size(), n_group) << "ValueError: The number of vars in `lhs` must equal to the " + "number of elements in `results`"; + CHECK_EQ(rhs.size(), n_group) << "ValueError: The number of vars in `rhs` must equal to the " + "number of elements in `results`"; + CHECK_EQ(identity_element.size(), n_group) + << "ValueError: The number of identities must equal to the number of elements in `results`"; + + // Change the dtype of input vars to adapt to the dtype of identities + ArrayNode* p_lhs = lhs.CopyOnWrite(); + ArrayNode* p_rhs = rhs.CopyOnWrite(); + std::unordered_map var_map; + var_map.reserve(n_group * 2); + for (int i = 0; i < static_cast(n_group); ++i) { + DataType dtype = identity_element[i].dtype(); + Var l = lhs[i].copy_with_dtype(dtype); + Var r = rhs[i].copy_with_dtype(dtype); + var_map[lhs[i].get()] = l; + var_map[rhs[i].get()] = r; + + p_lhs->SetItem(i, l); + p_rhs->SetItem(i, r); + } + + ArrayNode* p_result = result.CopyOnWrite(); + for (int i = 0; i < static_cast(n_group); ++i) { + p_result->SetItem(i, Substitute(result[i], var_map)); + } + auto node = make_object(); node->lhs = lhs; node->rhs = rhs; diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 7c54cdc85f82..93b052ee1d96 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3095,5 +3095,65 @@ def test_primfunc_with_allocate_annotations(): tvm.ir.assert_structural_equal(func, rt_func, True) +# fmt: off +@T.prim_func +def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + threadIdx_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128], dtype="float32") + for i in T.serial(0, 128): + T.launch_thread(threadIdx_x, 128) + reduce_temp0 = T.allocate([1], "float32", "local") + with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) + + +@T.prim_func +def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + threadIdx_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128], dtype="float32") + for i in T.serial(0, 128): + T.launch_thread(threadIdx_x, 128) + reduce_temp0 = T.allocate([1], "float32", "local") + with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) + + +@T.prim_func +def multiple_commreducer() -> None: + normal_reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cross_thread_reduction"): + T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0[0], True, reduce_temp0.data, ax0_1, dtype="handle")) + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_expsum_cross_thread_reduction"): + T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp1[0], True, reduce_temp1.data, ax0_1, dtype="handle")) +# fmt: on + + +def test_primfunc_with_single_reduce_group_commreducer(): + func = comm_reducer_single_reduce_group + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_primfunc_with_multiple_reduce_group_commreducer(): + func = comm_reducer_multiple_reduce_groups + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_primfunc_with_multiple_commreducer(): + func = multiple_commreducer + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From bdb311b93e88fd18825f3b426c080ff999ef216e Mon Sep 17 00:00:00 2001 From: Yaoyao Ding <994865602@qq.com> Date: Sun, 24 Oct 2021 18:28:31 -0400 Subject: [PATCH 83/84] [Fixbug] Report duplicated param names of relay function when bind params (#9350) * [Fixbug] Report duplicated param names of relay function when bind params * add test * lint --- src/relay/backend/utils.h | 2 +- tests/python/relay/test_ir_bind.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index febb550d45c0..a647aa1a3fd2 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -327,7 +327,7 @@ inline relay::Function BindParamsByName( for (auto arg : func->params) { const auto& name = arg->name_hint(); if (name_dict.count(name)) { - repeat_var.insert(arg); + repeat_var.insert(name_dict[name]); } else { name_dict[name] = arg; } diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index b179096a0528..0ab0122fa798 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. """ test bind function.""" +import pytest import tvm from tvm import te from tvm import relay +from tvm import TVMError def test_bind_params(): @@ -34,5 +36,16 @@ def test_bind_params(): assert tvm.ir.structural_equal(zbinded, zexpected) +def test_bind_duplicated_params(): + a = relay.var("a", shape=(1,)) + aa = relay.var("a", shape=(1,)) + s = a + aa + func = relay.Function([a, aa], s) + + with pytest.raises(TVMError): + relay.build_module.bind_params_by_name(func, {"a": [1.0]}) + + if __name__ == "__main__": test_bind_params() + test_bind_duplicated_params() From aa38997824235f43a9793e7d15fb8a8b5532e8fb Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 25 Oct 2021 10:14:38 +0800 Subject: [PATCH 84/84] [BugFix][TIR] Fix primitive `Bind` for init-inside blocks (#9359) * [BugFix][TIR] Fix primitive `Bind` for init-inside blocks * fix python black error --- src/tir/schedule/primitive/for_kind.cc | 5 ++ .../unittest/test_tir_schedule_for_kind.py | 46 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 008d47792f69..55869e12b6b2 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -121,6 +121,11 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind runtime::ThreadScope thread_scope) { PreOrderVisit(loop, [&](const ObjectRef& node) { if (const auto* realize = node.as()) { + // If this block doesn't have corresponding StmtSRef in the schedule state, it must be a block + // inside `tir.init()`. We don't check the condition for such blocks. + if (!self->stmt2ref.count(realize->block.get())) { + return false; + } CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef(realize), thread_scope); } diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 9075e93b9d45..93876c668913 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -239,6 +239,44 @@ def opaque_block(a: T.handle) -> None: A[i + 1] = A[i + 1] + A[i] +@T.prim_func +def block_inside_init(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i in T.serial(0, 128): + with T.block("outer"): + vi = T.axis.S(128, i) + with T.init(): + for j in T.serial(0, 128): + with T.block("init"): + vj = T.axis.S(128, j) + B[vi, vj] = 0.0 + for k in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block("inner"): + vj, vk = T.axis.remap("SR", [j, k]) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + +@T.prim_func +def thread_bound_block_inside_init(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("outer"): + vi = T.axis.S(128, i) + with T.init(): + for j in T.serial(0, 128): + with T.block("init"): + vj = T.axis.S(128, j) + B[vi, vj] = 0.0 + for k in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block("inner"): + vj, vk = T.axis.remap("SR", [j, k]) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + # pylint: enable=no-member,invalid-name,unused-variable @@ -361,5 +399,13 @@ def test_bind_after_bind(): verify_trace_roundtrip(s, mod=element_wise) +def test_block_inside_init(): + s = tir.Schedule(block_inside_init, debug_mask="all") + (i,) = s.get_loops(s.get_block("outer")) + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], thread_bound_block_inside_init) + verify_trace_roundtrip(s, mod=block_inside_init) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))