From 022d78f9cad19635b3b63cfbf9eac9243d04e952 Mon Sep 17 00:00:00 2001 From: ziheng Date: Sat, 16 Jan 2021 09:51:04 -0800 Subject: [PATCH] [TIR] Support Return in TIR (#7084) --- include/tvm/tir/builtin.h | 4 ++ include/tvm/tir/op.h | 9 ++++ include/tvm/tir/op_attr_types.h | 6 ++- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 28 ++++++++--- src/target/llvm/codegen_llvm.cc | 12 +++++ src/tir/op/builtin.cc | 4 ++ src/tir/op/op.cc | 4 ++ src/tir/transforms/make_packed_api.cc | 66 +++++++++++++++++++++++++- tests/python/unittest/test_tir_base.py | 60 +++++++++++++++++++++++ 10 files changed, 185 insertions(+), 10 deletions(-) create mode 100644 tests/python/unittest/test_tir_base.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index a150595ab551..6a40d86b8984 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -41,6 +41,10 @@ namespace tir { /*! \brief Collection of builtin intrinsics as ops */ namespace builtin { +/*! + * \brief Return value. + */ +TVM_DLL const Op& ret(); /*! * \brief Reinterpret the value using the target type. */ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 4a907fca951d..b5a62c907ed6 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -70,6 +70,15 @@ TVM_DLL Type GetType(const PrimExpr& expr); */ TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); +/*! + * \brief Return the value. + * + * \param value The returned value. + * \param span The location of this operation in the source. + * \return The return expression. + */ +TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span()); + /*! * Query the maximum possible value of dtype. * \param dtype The data type. diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index ec7fc172cde8..3dcc4b943a79 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -74,7 +74,11 @@ enum class CallEffectKind : int { /*! * \brief Embed opaque information in the Expr, cannot be codegen. */ - kEmbedInfo = 5 + kEmbedInfo = 5, + /*! + * \brief Function that changes control flow + */ + kControlJump = 6, }; /*! \brief Use integer to record the kind. */ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 1aac55fa9920..901c89ed9106 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -35,7 +35,7 @@ from .function import PrimFunc from .op import call_packed, call_intrin, call_pure_extern, call_extern -from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace +from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index ca61be4fcd83..182264f0db92 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -221,6 +221,22 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): ) +def ret(val): + """Create a tir return expression + + Parameters + ---------- + val : Expr + The returned tir expression, whose data type is int, float or void pointer. + + Returns + ------- + ret : PrimExpr + The return expression + """ + return call_intrin(val.dtype, "tir.ret", val) + + def any(*args, span=None): """Create a new experssion of the union of all conditions in the arguments @@ -241,10 +257,10 @@ def any(*args, span=None): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _ffi_api._OpOr(args[0], args[1], span) + val = _ffi_api._OpOr(args[0], args[1], span) for i in range(2, len(args)): - ret = _ffi_api._OpOr(ret, args[i], span) - return ret + val = _ffi_api._OpOr(val, args[i], span) + return val def all(*args, span=None): @@ -268,10 +284,10 @@ def all(*args, span=None): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _ffi_api._OpAnd(args[0], args[1], span) + val = _ffi_api._OpAnd(args[0], args[1], span) for i in range(2, len(args)): - ret = _ffi_api._OpAnd(ret, args[i], span) - return ret + val = _ffi_api._OpAnd(val, args[i], span) + return val @tvm._ffi.register_func("tvm.default_trace_action") diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 70f094a186e7..34f3897cce88 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -927,6 +927,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(then_value, then_value_block); value->addIncoming(else_value, else_value_block); return value; + } else if (op->op.same_as(builtin::ret())) { + auto const* val = op->args[0].as(); + ICHECK(val) << "the tir.ret should be transformed to return zero " + << "before the llvm code generation."; + ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to " + << "return zero before the llvm code generation."; + builder_->CreateRet(ConstInt32(0)); + // LLVM allows exactly one terminator in a single basic block + // append a new dummy basic block to avoid error. + llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_); + builder_->SetInsertPoint(ret_dummy); + return ret_dummy; } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 796b113a4054..1117571c8b75 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -42,6 +42,10 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(ret) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) + .set_num_inputs(1); + TIR_DEFINE_BUILTIN_FUNC(likely) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation)) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index b576fe4faee8..9fcb07149d19 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -145,6 +145,10 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } } +PrimExpr ret(PrimExpr value, Span span) { + return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); +} + // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { using namespace tir; diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 7c4a8ef92724..adbe78a6d627 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -41,6 +41,67 @@ namespace tvm { namespace tir { +class ReturnRewriter : public StmtMutator { + public: + explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} + + Stmt VisitStmt_(const ForNode* node) override { + if (node->for_type == ForType::Parallel) in_parallel_ += 1; + Stmt ret = StmtMutator::VisitStmt_(node); + if (node->for_type == ForType::Parallel) in_parallel_ -= 1; + return ret; + } + + Stmt VisitStmt_(const EvaluateNode* node) override { + Stmt ret = StmtMutator::VisitStmt_(node); + const EvaluateNode* eval = ret.as(); + ICHECK(eval); + if (const CallNode* call = eval->value.as()) { + if (call->op.same_as(builtin::ret())) { + ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope."; + ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; + ret = WriteToOut(call->args[0], ret_var_, ret_tcode_); + } + } + return ret; + } + + private: + std::pair ConvertForFFI(PrimExpr val) { + // convert val's data type to FFI data type, return type code + DataType dtype = val.dtype(); + if (dtype.is_int() || dtype.is_uint()) { + return {kTVMArgInt, Cast(DataType::Int(64), val)}; + } else if (dtype.is_float()) { + return {kTVMArgFloat, Cast(DataType::Float(64), val)}; + } else if (dtype.is_void()) { + return {kTVMNullptr, val}; + } else { + LOG(FATAL) << "data type " << dtype << " not supported yet"; + } + return {kTVMNullptr, val}; + } + + Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) { + auto p = ConvertForFFI(val); + int tcode = p.first; + val = p.second; + Stmt store_val = Store(ret_var_, val, 0, const_true()); + Stmt store_tcode = Store(ret_tcode_, tcode, 0, const_true()); + Stmt ret_zero = Evaluate(tvm::ret(0)); + return SeqStmt({store_val, store_tcode, ret_zero}); + } + + Var ret_var_; + Var ret_tcode_; + int in_parallel_{0}; +}; + +Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { + ReturnRewriter rewriter(ret_var, ret_tcode); + return rewriter(body); +} + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } @@ -182,8 +243,9 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); } - Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, - StringImm(name_hint + "_compute_"), func_ptr->body); + Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode); + body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, + StringImm(name_hint + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { PrimExpr node = StringImm("default"); diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py new file mode 100644 index 000000000000..6e081a179059 --- /dev/null +++ b/tests/python/unittest/test_tir_base.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir +from tvm.ir.transform import PassContext + + +def build_tir_func(func): + func = func.with_attr("global_symbol", "main") + pass_ctx = PassContext.current() + if pass_ctx.config.get("tir.noalias", True): + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({"main": func}) + func = tvm.build(mod) + return func + + +def test_scalar_add(): + a = tir.Var("a", "float32") + b = tir.Var("b", "float32") + c = a + b + c = tir.ret(c) + c = tir.Evaluate(c) + func = tir.PrimFunc([a, b], c) + func = build_tir_func(func) + out = func(1.0, 2.0) + assert out == 3.0 + + +def test_control_flow_jump(): + ib = tvm.tir.ir_builder.create() + a = tir.Var("a", "float32") + b = tir.Var("b", "float32") + with ib.if_scope(True): + ib.emit(tir.Evaluate(tir.ret(a))) + ib.emit(tir.Evaluate(tir.ret(b))) + stmt = ib.get() + func = tir.PrimFunc([a, b], stmt) + func = build_tir_func(func) + out = func(1.0, 2.0) + assert out == 1.0 + + +if __name__ == "__main__": + test_scalar_add() + test_control_flow_jump()