diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 2113d58f1ffa7..5ee847e2f0109 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -418,6 +418,11 @@ TVM_DLL Pass ConvertBlocksToOpaque(); */ TVM_DLL Pass CompactBufferAllocation(); +/*! + * This pass legalizes packed calls by wrapping their arguments into TVMValues + */ +TVM_DLL Pass LegalizePackedCalls(); + /*! * \brief Flatten the multi-dimensional BufferLoad and BufferStore * to single dimensional Load/Store. Also remove Block to diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 1dfa09ffcce9c..1e3f101703c82 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -269,50 +269,11 @@ class AOTExecutorCodegen : public ExprVisitor { } auto sid_value = sids_table_[sid]; - if (!use_unpacked_api_) { - // Pack the sid inside the TVMValue - auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle()); - tvm::PrimExpr set_tensor = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrData, sid_value}); - stmts_.push_back( - tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); - buffer_vars.push_back(sid_array); - } else { - buffer_vars.push_back(sid_value); - } + buffer_vars.push_back(sid_value); } return buffer_vars; } - /*! - * \brief Utility function to return a parameter associated with an expression - * \param expr Relay Expression associated with the parameter - * \return Variable that represents the DLTensor associated with the parameters - */ - tir::Var PackParam(Expr expr) { - int param_sid = param_storage_ids_[params_by_expr_[expr]]; - auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle()); - - // Compose the lookup_call using a local stack - Array lookup_call; - // Set the param to the value returned by lookup_call - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[expr])}); - - if (!use_unpacked_api_) { - tvm::PrimExpr set_param_array = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {param_array, 0, tir::builtin::kArrData, param_handle}); - stmts_.push_back( - tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array))); - } else { - stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0))); - } - - return param_array; - } - /*! * brief Given an expression return the variable(s) associated with that expression */ @@ -322,9 +283,6 @@ class AOTExecutorCodegen : public ExprVisitor { // Input variable int main_index = std::distance(input_vars_.begin(), input_iter); return {main_signature_[main_index]}; - } else if (params_by_expr_.find(arg) != params_by_expr_.end()) { - // Parameter of the network - return {PackParam(arg)}; } else { // Storage identifier (i.e., intermediate memory) return PackSid(arg); @@ -340,8 +298,14 @@ class AOTExecutorCodegen : public ExprVisitor { // Pack the inputs for (Expr arg : call->args) { - auto var_arg = FindExpr(arg); - args.push_back(var_arg[0]); + if (params_by_expr_.find(arg) != params_by_expr_.end()) { + auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[arg])}); + args.push_back(param_handle); + } else { + auto var_arg = FindExpr(arg); + args.push_back(var_arg[0]); + } } auto ret_expr = Downcast(call); @@ -369,7 +333,7 @@ class AOTExecutorCodegen : public ExprVisitor { * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a * copy-on-write fashion. */ - void CopyToOutput(te::Var out, te::Var in, size_t size) { + void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { // Define intermediate DLTensor to load/store the data auto tmp0 = te::Var("tmp0", DataType::Handle()); auto tmp1 = te::Var("tmp1", DataType::Handle()); @@ -381,10 +345,15 @@ class AOTExecutorCodegen : public ExprVisitor { PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), {out, 0, tir::builtin::kArrData}); if (use_unpacked_api_) { - retval_get = in; tostore = out; } + // Do not pack the input if the flag is set or the caller + // explicitly asked to do so (e.g., copying a param to the output) + if (use_unpacked_api_ || !pack_input) { + retval_get = in; + } + // Copy the variable from the input to the output tir::Stmt copy = tir::For( loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, @@ -563,9 +532,16 @@ class AOTExecutorCodegen : public ExprVisitor { auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), buffers[0].sid); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); - auto var_expr = FindExpr(expr); - CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], - buffers[0].size_bytes); + if (params_by_expr_.find(expr) != params_by_expr_.end()) { + auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[expr])}); + CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, + /*pack_input*/ true, buffers[0].size_bytes); + } else { + auto var_expr = FindExpr(expr); + CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], + /*pack_input*/ true, buffers[0].size_bytes); + } } } @@ -584,7 +560,9 @@ class AOTExecutorCodegen : public ExprVisitor { auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), buffers[0].sid); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); - CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), + auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[expr])}); + CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, false, buffers[0].size_bytes); } } @@ -625,7 +603,9 @@ class AOTExecutorCodegen : public ExprVisitor { throw std::invalid_argument("match case not yet implemented"); } - // Create the main PrimFunc to execute the graph + // Create the main PrimFunc to execute the graph. Please note that + // the packed function calls don't pack their arguments. The AOT + // runner function needs to be legalized by the LegalizePackedCalls pass. tir::PrimFunc CreateMainFunc(unsigned int relay_params) { tir::Stmt body = tir::SeqStmt(stmts_); @@ -757,6 +737,9 @@ class AOTExecutorCodegen : public ExprVisitor { VisitExpr(func->body); + // Create the runner function. Please note that the function is not legal yet + // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need + // to run the LegalizePackedCalls pass. auto prim_func = CreateMainFunc(func->params.size()); UpdateMainWorkspaceSize(prim_func, func); LoweredOutput ret; @@ -787,6 +770,13 @@ class AOTExecutorCodegen : public ExprVisitor { auto storage_rewrite = tir::transform::StorageRewrite(); mod_run = storage_rewrite(mod_run); + // Legalize AOT if needed. This means that all the packed calls + // need to be wrapped in TVMValues (unless use_unpacked_api is set) + if (!use_unpacked_api_) { + auto pack_calls = tir::transform::LegalizePackedCalls(); + mod_run = pack_calls(mod_run); + } + // Update the lowered functions auto target_host_str = target_host_->str(); if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 3b4e693b820a8..648de2792b021 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -161,6 +161,27 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { return align; } +/*! + * \brief Create an int32 constant + * \param index the value of the constant + * \return the PrimExpr that represents the constant + */ +inline PrimExpr ConstInt32(size_t index) { + ICHECK_LE(index, std::numeric_limits::max()); + return make_const(DataType::Int(32), static_cast(index)); +} + +/*! + * \brief Allocate TVMValues on the stack + * \param type type of allocation + * \param num number of TVMValues to allocate + * \return PrimExpr representing the TVMValue + */ +inline PrimExpr StackAlloca(std::string type, size_t num) { + Array args = {StringImm(type), ConstInt32(num)}; + return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args); +} + /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc new file mode 100644 index 0000000000000..70dfec161c3b1 --- /dev/null +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file make_packed_call.cc + * \brief Rewrite packed calls in AOT so that the arguments are packed + */ +#include +#include +#include +#include +#include +#include + +#include + +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +using InputMap = + std::unordered_map; +/** + * This is a legalization pass only used in AOT. Traverse the TIR graph to legalize + * packed calls by making its argument wrapped in TVMValues (by using tvm_set_struct built-in) + */ +class PackedCallLegalizer : public StmtExprMutator { + public: + Stmt Legalize(const InputMap& params, tir::Stmt body) { + inputs_ = params; + return StmtExprMutator::VisitStmt(body); + } + + Stmt VisitStmt_(const EvaluateNode* op) final { + if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op); + const CallNode* call = op->value.as(); + // Given a packed call f(A,B,C), we need a set of new statements + // let A_packed = set_struct(tvm_value1, A) + // let B_packed = set_struct(tvm_value2, B) + // let C_packed = set_struct(tvm_value3, C) + // call_packed(f, A_packed, B_packed, C_packed) + std::vector new_stmts; + if (call) { + if (call->op.same_as(builtin::tvm_call_cpacked())) { + Array packed_args{call->args[0]}; + for (unsigned i = 1; i < call->args.size(); i++) { + // No need to pack inputs of the prim_func + if (inputs_[call->args[i]] == true) { + packed_args.push_back(call->args[i]); + } else { + // Pack the argument inside a TVMValue + auto sid_array = tir::Var("tvm_value", DataType::Handle()); + tir::Stmt set_struct_stmt = tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {sid_array, 0, tir::builtin::kArrData, call->args[i]})); + new_stmts.push_back(LetStmt(sid_array, StackAlloca("array", 1), set_struct_stmt)); + packed_args.push_back(sid_array); + } + } + // Finally, evaluate the packed call and return a sequential statement + new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args))); + return tir::SeqStmt(new_stmts); + } + } + return StmtExprMutator::VisitStmt_(op); + } + + private: + InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed. +}; + +namespace transform { + +Pass LegalizePackedCalls() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + + // Create the + InputMap inputs; + for (auto i : f->params) { + inputs[i] = true; + } + n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); + return std::move(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); +} +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 0e2e612e3ae8d..8b70817398e4a 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -34,16 +34,6 @@ namespace tvm { namespace tir { -inline PrimExpr ConstInt32(size_t index) { - ICHECK_LE(index, std::numeric_limits::max()); - return make_const(DataType::Int(32), static_cast(index)); -} - -inline PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {StringImm(type), ConstInt32(num)}; - return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args); -} - // Calculate the statistics of packed function. // These information are needed during codegen. class BuiltinLower : public StmtExprMutator { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index cd91a4b53317d..36eeddb17d89b 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -138,35 +138,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); this->VisitExpr(l->index); - } else if (op->op.same_as(builtin::tvm_call_cpacked())) { - // Recall that the arguments of a tvm_call_cpacked are passed as - // TVMValues. But a TVMValue is only a container, that points to - // a real buffer previously allocated. We need to signal that those - // buffers need to be live at the same time (i.e., cannot be overwritten during the function - // call) - Array args = op->args; - for (auto arg : args) { - const VarNode* var = arg.as(); - if (value_to_alloc_.find(var) != value_to_alloc_.end()) { - auto allocs = value_to_alloc_[var]; - for (const VarNode* alloc : allocs) { - VisitExpr_(alloc); - } - } else { - this->VisitExpr(arg); - } - } - } else if (op->op.same_as(builtin::tvm_struct_set())) { - // If we are using a struct_set built-in, and we are setting - // a DLTensor ArrayData field, let's note down the - // buffers that the TVMValue refers to - const VarNode* var = op->args[0].as(); - const VarNode* alloc = op->args[3].as(); - const int field_id = op->args[2].as()->value; - if (var && alloc && field_id == tir::builtin::kArrData) { - value_to_alloc_[var].push_back(alloc); - } - StmtExprVisitor::VisitExpr_(op); } else { StmtExprVisitor::VisitExpr_(op); } @@ -235,13 +206,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { bool in_thread_env_{false}; // The scope stack. std::vector scope_; - // This is a map to connect TVMValues to real allocations. When we pass parameters - // to a tvm_call_cpacked, the data needs to be wrapped in a TVMValue. The wrapping - // happens through the tvm_struct_set built-in. This map is mapping the variable - // representing the TVMValue to the variable representing the real buffer. The live - // analysis needs to happen on the latter and not on the TVMValue which only acts as - // a container. - std::unordered_map> value_to_alloc_; }; // Verify if the statement can be run safely via inplace fashion @@ -923,11 +887,11 @@ class StoragePlanRewriter : public StmtExprMutator { // symbolic free list, for non constant items. std::list sym_free_list_; // The allocation attach map - std::unordered_map> attach_map_; + std::unordered_map > attach_map_; // The allocation assign map std::unordered_map alloc_map_; // The allocations - std::vector> alloc_vec_; + std::vector > alloc_vec_; // analyzer arith::Analyzer analyzer_; }; @@ -986,7 +950,7 @@ class VectorAllocRewriter : public StmtExprMutator { } // Internal access map - std::unordered_map> acc_map_; + std::unordered_map > acc_map_; // Variables to remap Map var_remap_; // internal analyzer