From 681f266af91e9b31bedd206695d5a1872194e41a Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Wed, 30 Jun 2021 00:16:29 +0100 Subject: [PATCH] Decoupling AOT from graph memory planner (#8096) * Fix an issue with storage-rewrite pass and packed functions Change-Id: I13888471d4b8927a4012d6a8e749fb7a8935dd77 * Rebasing Change-Id: I7aa12e0217b8a2e1ff2a97a7c5fdda6b7597ae64 * Addressing comments Change-Id: If9f1ee190690f9a810fe41eb1933d736f1eb4ec3 * Add a pass to legalize packed calls Change-Id: I8aa43d3a1b837b03a5cf3c6b32fc760bd78d3436 * Add a unit test for the legalization pass Change-Id: I5b0d75380ff660dd5a0acf5b14fa84bb992fbec4 * rebasing Change-Id: I52ceab5cf6e9b54390cb36c18dbb8e22505d8e18 * Use common StorageInfo Change-Id: Ia8b7de1373f167ca7d0d69a99846d417405bbe48 --- include/tvm/tir/transform.h | 5 + python/tvm/tir/transform/transform.py | 11 + src/relay/backend/aot_executor_codegen.cc | 354 ++++++++++++------ src/tir/transforms/ir_utils.h | 23 ++ src/tir/transforms/legalize_packed_calls.cc | 121 ++++++ src/tir/transforms/lower_tvm_builtin.cc | 10 - tests/python/relay/aot/aot_test_utils.py | 47 ++- tests/python/relay/aot/test_crt_aot.py | 48 +++ .../unittest/test_aot_legalize_packed_call.py | 80 ++++ 9 files changed, 573 insertions(+), 126 deletions(-) create mode 100644 src/tir/transforms/legalize_packed_calls.cc create mode 100644 tests/python/unittest/test_aot_legalize_packed_call.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 2113d58f1ffa..5ee847e2f010 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -418,6 +418,11 @@ TVM_DLL Pass ConvertBlocksToOpaque(); */ TVM_DLL Pass CompactBufferAllocation(); +/*! + * This pass legalizes packed calls by wrapping their arguments into TVMValues + */ +TVM_DLL Pass LegalizePackedCalls(); + /*! * \brief Flatten the multi-dimensional BufferLoad and BufferStore * to single dimensional Load/Store. Also remove Block to diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 8a32a7e6dff0..51330f80afc6 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -451,6 +451,17 @@ def LowerTVMBuiltin(): return _ffi_api.LowerTVMBuiltin() +def LegalizePackedCalls(): + """Legalize packed calls to have its arguments wrapped in TVMValues + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LegalizePackedCalls() + + def LowerIntrin(): """Lower target specific intrinsic calls. diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 93935af70fca..9b495adbdea8 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -46,50 +47,175 @@ namespace backend { using IntegerArray = Array; using TargetsMap = std::unordered_map; +using StorageMap = + std::unordered_map; -class AotReturnSidVisitor : public ExprVisitor { +/** + * This is an on demand allocator for AOT. A new temporary + * (storage allocator identifier) is allocated for each operation. + */ +class AOTOnDemandAllocator : public ExprVisitor { public: - explicit AotReturnSidVisitor(Map> storage_device_map) - : storage_device_map_{storage_device_map}, return_sid_{-1} {} + // run the visitor on a function. 
+ void Run(const Function& func) { + node_device_map_ = CollectDeviceInfo(func); - IntegerArray FindReturnSid(Function func) { - VisitExpr(func->body); - return return_sid_; + for (Expr param : func->params) { + CreateStorage(param.operator->()); + } + + GetStorage(func->body); } - protected: - void AssignReturnSid(Expr e) { - auto iter = storage_device_map_.find(e); - if (iter != storage_device_map_.end()) { - return_sid_ = (*iter).second[0]; + std::vector GetReturnIds() const { return return_ids_; } + + StorageMap GetStorageMap() const { return storage_device_map_; } + + void VisitExpr_(const ConstantNode* op) final { + CreateStorage(op); + AssignReturnSid(GetRef(op)); + } + + void VisitExpr_(const CallNode* op) final { + // create token for the call node. + CreateStorage(op); + for (Expr arg : op->args) { + GetStorage(arg); } + AssignReturnSid(GetRef(op)); } - void VisitExpr_(const ConstantNode* cn) override { - ExprVisitor::VisitExpr_(cn); - AssignReturnSid(GetRef(cn)); + void VisitExpr_(const VarNode* op) final { + ExprVisitor::VisitExpr_(op); + AssignReturnSid(GetRef(op)); } - void VisitExpr_(const VarNode* vn) override { - ExprVisitor::VisitExpr_(vn); - AssignReturnSid(GetRef(vn)); + void VisitExpr_(const FunctionNode* op) final { + // do not recurse into sub function. } - void VisitExpr_(const CallNode* cn) override { - ExprVisitor::VisitExpr_(cn); - AssignReturnSid(GetRef(cn)); + void VisitExpr_(const GlobalVarNode* op) final { + // Do nothing. } - void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); } + void VisitExpr_(const OpNode* op) final { + // Do nothing. + } - void VisitExpr_(const TupleNode* tn) override { - ExprVisitor::VisitExpr_(tn); - AssignReturnSid(GetRef(tn)); + void VisitExpr_(const TupleNode* op) final { + std::vector storage_ids; + std::vector device_types; + std::vector storage_sizes_in_bytes; + Expr expr = GetRef(op); + for (Expr field : op->fields) { + auto sid = GetStorage(field); + storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end()); + device_types.insert(device_types.end(), sid->device_types.begin(), sid->device_types.end()); + storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(), + sid->storage_sizes_in_bytes.begin(), + sid->storage_sizes_in_bytes.end()); + } + storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes); + AssignReturnSid(expr); } + void VisitExpr_(const TupleGetItemNode* op) final { + Expr expr = GetRef(op); + auto sids = GetStorage(op->tuple); + ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); + storage_device_map_[expr] = + StorageInfo({sids->storage_ids[op->index]}, {sids->device_types[op->index]}, + {sids->storage_sizes_in_bytes[op->index]}); + AssignReturnSid(expr); + } + + void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } + + void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "let is not supported."; } + private: - Map> storage_device_map_; - IntegerArray return_sid_; + void AssignReturnSid(Expr e) { + if (storage_device_map_.find(e) != storage_device_map_.end()) { + StorageInfo& sinfo = storage_device_map_[e]; + return_ids_.clear(); + for (auto sid : sinfo->storage_ids) { + return_ids_.push_back(sid); + } + } + } + /*! + * \brief ceil(size/word_size) to get number of words. + * \param size The original size. + * \param word_size The element size. + */ + static size_t DivRoundUp(size_t size, size_t word_size) { + return (size + word_size - 1) / word_size; + } + /*! 
+   * \brief Get the memory requirement.
+   * \param ttype The tensor type whose memory requirement is computed.
+   * \return The required memory size in bytes.
+   */
+  size_t GetMemorySizeBytes(const TensorTypeNode* ttype) {
+    ICHECK(ttype != nullptr);
+    size_t size = 1;
+    for (IndexExpr dim : ttype->shape) {
+      const int64_t* pval = tir::as_const_int(dim);
+      ICHECK(pval != nullptr) << "Cannot allocate memory for symbolic tensor shape " << ttype->shape;
+      ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape " << *pval;
+      size *= static_cast<size_t>(pval[0]);
+    }
+    size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
+    return size;
+  }
+  /*!
+   * \brief Get the necessary storage for the expression.
+   * \param expr The expression.
+   * \return The corresponding storage info.
+   */
+  StorageInfo GetStorage(const Expr& expr) {
+    this->VisitExpr(expr);
+    auto it = storage_device_map_.find(expr);
+    ICHECK(it != storage_device_map_.end());
+    return it->second;
+  }
+
+  /*!
+   * \brief Create storage for the expression.
+   * \param op The expression node.
+   */
+  void CreateStorage(const ExprNode* op) {
+    std::vector<int> storage_ids;
+    std::vector<DLDeviceType> device_types;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    Expr expr = GetRef<Expr>(op);
+    int device_type_int =
+        node_device_map_.count(GetRef<Expr>(op)) ? node_device_map_[expr]->value : 0;
+    if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        const auto* ttype = t.as<TensorTypeNode>();
+        ICHECK(ttype);
+        storage_ids.push_back(next_available_sid_++);
+        storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
+        device_types.push_back(DLDeviceType(device_type_int));
+      }
+    } else {
+      const auto* ttype = op->checked_type().as<TensorTypeNode>();
+      ICHECK(ttype);
+      storage_ids.push_back(next_available_sid_++);
+      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
+      device_types.push_back(DLDeviceType(device_type_int));
+    }
+    storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes);
+  }
+  /*! \brief mapping of expression -> StorageInfo */
+  StorageMap storage_device_map_;
+  /*! \brief mapping of expression -> device type */
+  Map<Expr, Integer> node_device_map_;
+  /*! \brief current id of the allocated temporary */
+  int next_available_sid_{0};
+  /*! \brief the set of intermediate tensors that are return variables */
+  std::vector<int> return_ids_;
 };

 /*!
\brief Code generator for AOT executor */ @@ -120,65 +246,24 @@ class AOTExecutorCodegen : public ExprVisitor { * \brief Return a vector of variables that represents the sids for the given Relay Expr */ std::vector PackSid(Expr expr) { - Array sids = storage_device_map_[expr]; - std::vector sid_vars; + std::vector buffer_vars; + StorageInfo& sinfo = storage_device_map_[expr]; // Note that an expression can have multiple sids associated with it // e.g., returning multiple values from a function - for (const auto& sid : sids[0]) { + for (auto sid : sinfo->storage_ids) { // Determine if an sid is an output buffer - int sid_int = static_cast((sid.as())->value); - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int); + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); - sid_vars.push_back(main_signature_[input_vars_.size() + output_index]); + buffer_vars.push_back(main_signature_[input_vars_.size() + output_index]); continue; } - // Pack the sid inside the TVMValue - auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle()); - auto sid_value = sids_table_[sid]; - if (!use_unpacked_api_) { - tvm::PrimExpr set_tensor = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrData, sid_value}); - stmts_.push_back( - tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); - } else { - stmts_.push_back(tir::LetStmt(sid_array, sid_value, tir::Evaluate(0))); - } - - sid_vars.push_back(sid_array); + auto sid_value = sids_table_[sid]; + buffer_vars.push_back(sid_value); } - return sid_vars; - } - - /*! - * \brief Utility function to return a parameter associated with an expression - * \param expr Relay Expression associated with the parameter - * \return Variable that represents the DLTensor associated with the parameters - */ - tir::Var PackParam(Expr expr) { - int param_sid = param_storage_ids_[params_by_expr_[expr]]; - auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle()); - - // Compose the lookup_call using a local stack - Array lookup_call; - // Set the param to the value returned by lookup_call - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[expr])}); - - if (!use_unpacked_api_) { - tvm::PrimExpr set_param_array = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {param_array, 0, tir::builtin::kArrData, param_handle}); - stmts_.push_back( - tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array))); - } else { - stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0))); - } - - return param_array; + return buffer_vars; } /*! 
@@ -190,9 +275,6 @@ class AOTExecutorCodegen : public ExprVisitor { // Input variable int main_index = std::distance(input_vars_.begin(), input_iter); return {main_signature_[main_index]}; - } else if (params_by_expr_.find(arg) != params_by_expr_.end()) { - // Parameter of the network - return {PackParam(arg)}; } else { // Storage identifier (i.e., intermediate memory) return PackSid(arg); @@ -208,8 +290,14 @@ class AOTExecutorCodegen : public ExprVisitor { // Pack the inputs for (Expr arg : call->args) { - auto var_arg = FindExpr(arg); - args.push_back(var_arg[0]); + if (params_by_expr_.find(arg) != params_by_expr_.end()) { + auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[arg])}); + args.push_back(param_handle); + } else { + auto var_arg = FindExpr(arg); + args.push_back(var_arg[0]); + } } auto ret_expr = Downcast(call); @@ -237,7 +325,7 @@ class AOTExecutorCodegen : public ExprVisitor { * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a * copy-on-write fashion. */ - void CopyToOutput(te::Var out, te::Var in, size_t size) { + void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { // Define intermediate DLTensor to load/store the data auto tmp0 = te::Var("tmp0", DataType::Handle()); auto tmp1 = te::Var("tmp1", DataType::Handle()); @@ -249,10 +337,15 @@ class AOTExecutorCodegen : public ExprVisitor { PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), {out, 0, tir::builtin::kArrData}); if (use_unpacked_api_) { - retval_get = in; tostore = out; } + // Do not pack the input if the flag is set or the caller + // explicitly asked to do so (e.g., copying a param to the output) + if (use_unpacked_api_ || !pack_input) { + retval_get = in; + } + // Copy the variable from the input to the output tir::Stmt copy = tir::For( loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, @@ -390,8 +483,8 @@ class AOTExecutorCodegen : public ExprVisitor { } ICHECK_GE(storage_device_map_.count(expr), 0); - auto& device_type = storage_device_map_[expr][1]; - auto call_dev_type = device_type[0]->value; + StorageInfo& sinfo = storage_device_map_[expr]; + auto call_dev_type = sinfo->device_types[0]; // Normal Relay Function if (targets_.size() == 1) { // homogeneous execution. 
@@ -425,17 +518,23 @@ class AOTExecutorCodegen : public ExprVisitor { void VisitExpr_(const VarNode* op) override { Expr expr = GetRef(op); + StorageInfo& sinfo = storage_device_map_[expr]; // If the Var node is an output node we need to copy the content of the variable to the output // It's safe to check the SID here because Var StorageToken are never reallocated - Array sids = storage_device_map_[expr]; - - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), - static_cast((sids[0][0].as())->value)); + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); - auto var_expr = FindExpr(expr); - CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]); + if (params_by_expr_.find(expr) != params_by_expr_.end()) { + auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[expr])}); + CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, + /*pack_input*/ true, sinfo->storage_sizes_in_bytes[0]); + } else { + auto var_expr = FindExpr(expr); + CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], + /*pack_input*/ true, sinfo->storage_sizes_in_bytes[0]); + } } } @@ -443,19 +542,20 @@ class AOTExecutorCodegen : public ExprVisitor { Expr expr = GetRef(op); size_t index = params_.size(); std::string name = "p" + std::to_string(index); - - param_storage_ids_[name] = storage_device_map_[expr][0][0]->value; + StorageInfo& sinfo = storage_device_map_[expr]; + param_storage_ids_[name] = sinfo->storage_ids[0]; params_[name] = op->data; params_by_expr_.Set(expr, name); // If the Constant node is an output node we need to copy the content of the parameter to the // output A Var node can only produce a single output - Array sids = storage_device_map_[expr]; - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), - static_cast((sids[0][0].as())->value)); + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); - CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), sids[2][0]); + auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[expr])}); + CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, false, + sinfo->storage_sizes_in_bytes[0]); } } @@ -495,7 +595,9 @@ class AOTExecutorCodegen : public ExprVisitor { throw std::invalid_argument("match case not yet implemented"); } - // Create the main PrimFunc to execute the graph + // Create the main PrimFunc to execute the graph. Please note that + // the packed function calls don't pack their arguments. The AOT + // runner function needs to be legalized by the LegalizePackedCalls pass. 
tir::PrimFunc CreateMainFunc(unsigned int relay_params) { tir::Stmt body = tir::SeqStmt(stmts_); @@ -511,9 +613,9 @@ class AOTExecutorCodegen : public ExprVisitor { continue; } - for (unsigned int i = 0; i < kv.second[0].size(); i++) { - int size = kv.second[2][i]; - int sid = static_cast((kv.second[0][i].as())->value); + for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) { + int size = kv.second->storage_sizes_in_bytes[i]; + int sid = kv.second->storage_ids[i]; if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) { continue; @@ -523,6 +625,8 @@ class AOTExecutorCodegen : public ExprVisitor { // so we don't pay the price of allocation for every inference if (!allocated[sid]) { body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body); + body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"), + body); } allocated[sid] = true; } @@ -578,7 +682,8 @@ class AOTExecutorCodegen : public ExprVisitor { std::unordered_map param_storage_ids_; /*! \brief plan memory of device result */ - Map> storage_device_map_; + StorageMap storage_device_map_; + /*! \brief mapping sid -> tir::Var */ std::unordered_map sids_table_; /*! \brief lowered funcs */ std::unordered_map lowered_funcs_; @@ -589,7 +694,7 @@ class AOTExecutorCodegen : public ExprVisitor { /*! \brief the set of statements that make the program */ std::vector stmts_; /*! \brief the list of return sids (note that the function might return more then one output */ - IntegerArray return_sid_; + std::vector return_sid_; /*! \brief the module name we use to mangle the function names */ String mod_name_; @@ -602,9 +707,11 @@ class AOTExecutorCodegen : public ExprVisitor { compile_engine_(CompileEngine::Global()) {} LoweredOutput Codegen(relay::Function func, String mod_name) { - // Get the module, storage map and token sizes - auto pf = GetPackedFunc("relay.backend.GraphPlanMemory"); - storage_device_map_ = (*pf)(func); + auto aot_allocator = AOTOnDemandAllocator(); + aot_allocator.Run(func); + + // Retrieve the storage map + storage_device_map_ = aot_allocator.GetStorageMap(); mod_name_ = mod_name; for (auto input : func->params) { @@ -614,20 +721,23 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { - for (const auto& sid : kv.second[0]) { - te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)))); - sids_table_[sid] = sid_var; + for (auto sid : kv.second->storage_ids) { + te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)))); + sids_table_[sid] = buffer_var; } } - // Find the return sid - return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func); + // Retrieve the return sids + return_sid_ = aot_allocator.GetReturnIds(); for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) { main_signature_.push_back(tir::Var("output", DataType::Handle())); } VisitExpr(func->body); + // Create the runner function. Please note that the function is not legal yet + // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need + // to run the LegalizePackedCalls pass. 
auto prim_func = CreateMainFunc(func->params.size()); UpdateMainWorkspaceSize(prim_func, func); LoweredOutput ret; @@ -649,14 +759,28 @@ class AOTExecutorCodegen : public ExprVisitor { } ret.external_mods = compile_engine_->LowerExternalFunctions(); + // Build the TIR IRModule + Map symbol_map; + symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); + IRModule mod_run(symbol_map); + + // Apply storage rewrite pass to the runner function to do memory planning + auto storage_rewrite = tir::transform::StorageRewrite(); + mod_run = storage_rewrite(mod_run); + + // Legalize AOT if needed. This means that all the packed calls + // need to be wrapped in TVMValues (unless use_unpacked_api is set) + if (!use_unpacked_api_) { + auto pack_calls = tir::transform::LegalizePackedCalls(); + mod_run = pack_calls(mod_run); + } + + // Update the lowered functions auto target_host_str = target_host_->str(); if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { - ret.lowered_funcs[target_host_str]->Add( - GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); + ret.lowered_funcs[target_host_str]->Update(mod_run); } else { - Map symbol_map; - symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); - ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map)); + ret.lowered_funcs.Set(target_host_str, mod_run); } ret.function_metadata = std::move(function_metadata_); ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(), diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 3b4e693b820a..906ff8a38b6c 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -29,6 +29,8 @@ #include #include +#include +#include #include namespace tvm { @@ -161,6 +163,27 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { return align; } +/*! + * \brief Create an int32 constant + * \param index the value of the constant + * \return the PrimExpr that represents the constant + */ +inline PrimExpr ConstInt32(size_t index) { + ICHECK_LE(index, std::numeric_limits::max()); + return make_const(DataType::Int(32), static_cast(index)); +} + +/*! + * \brief Allocate TVMValues on the stack + * \param type type of allocation + * \param num number of TVMValues to allocate + * \return PrimExpr representing the TVMValue + */ +inline PrimExpr StackAlloca(std::string type, size_t num) { + Array args = {StringImm(type), ConstInt32(num)}; + return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args); +} + /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc new file mode 100644 index 000000000000..424da1e817b6 --- /dev/null +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file legalize_packed_calls.cc
+ * \brief Rewrite packed calls in AOT so that their arguments are packed
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+using InputMap =
+    std::unordered_map<PrimExpr, bool, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+/**
+ * This is a legalization pass used only in AOT. It traverses the TIR graph to legalize
+ * packed calls by wrapping their arguments in TVMValues (via the tvm_struct_set built-in).
+ */
+class PackedCallLegalizer : public StmtExprMutator {
+ public:
+  Stmt Legalize(const InputMap& params, tir::Stmt body) {
+    inputs_ = params;
+    return StmtExprMutator::VisitStmt(body);
+  }
+
+  Stmt VisitStmt_(const EvaluateNode* op) final {
+    if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op);
+    const CallNode* call = op->value.as<CallNode>();
+    // Given a packed call f(A, B, C), we need a set of new statements:
+    //   let A_packed = set_struct(tvm_value1, A)
+    //   let B_packed = set_struct(tvm_value2, B)
+    //   let C_packed = set_struct(tvm_value3, C)
+    //   call_packed(f, A_packed, B_packed, C_packed)
+    std::vector<tir::Stmt> new_stmts;
+    if (call) {
+      if (call->op.same_as(builtin::tvm_call_cpacked())) {
+        Array<PrimExpr> packed_args{call->args[0]};
+        std::vector<tir::Var> tvm_values;
+        for (unsigned i = 1; i < call->args.size(); i++) {
+          // No need to pack inputs of the prim_func
+          if (inputs_[call->args[i]] == true) {
+            packed_args.push_back(call->args[i]);
+          } else {
+            // Pack the argument inside a TVMValue
+            std::stringstream ss;
+            ss << "tvm_value_" << tvm_value_index_++;
+            auto sid_array = tir::Var(ss.str(), DataType::Handle());
+            tvm_values.push_back(sid_array);
+
+            new_stmts.push_back(tir::Evaluate(
+                tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                               {sid_array, 0, tir::builtin::kArrData, call->args[i]})));
+            packed_args.push_back(sid_array);
+          }
+        }
+        // Evaluate the packed call
+        new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)));
+        tir::Stmt call_stmt = tir::SeqStmt(new_stmts);
+
+        // Allocate the TVMValues on the stack and define the variables
+        for (auto v : tvm_values) {
+          call_stmt = LetStmt(v, StackAlloca("array", 1), call_stmt);
+        }
+        return call_stmt;
+      }
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  InputMap inputs_;  // Store the inputs to the primfunc that don't need to be packed.
+ int tvm_value_index_; // Index of the actual tvm_value variable +}; + +namespace transform { + +Pass LegalizePackedCalls() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + + // Create the + InputMap inputs; + for (auto i : f->params) { + inputs[i] = true; + } + n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); + return std::move(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LegalizePackedCalls").set_body_typed(LegalizePackedCalls); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 0e2e612e3ae8..8b70817398e4 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -34,16 +34,6 @@ namespace tvm { namespace tir { -inline PrimExpr ConstInt32(size_t index) { - ICHECK_LE(index, std::numeric_limits::max()); - return make_const(DataType::Int(32), static_cast(index)); -} - -inline PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {StringImm(type), ConstInt32(num)}; - return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args); -} - // Calculate the statistics of packed function. // These information are needed during codegen. class BuiltinLower : public StmtExprMutator { diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index a18a0fa7dbe7..836ff4b22b20 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -42,6 +42,46 @@ def mangle_name(mod_name, name): return mod_name + "_" + name +def convert_to_relay( + tflite_model_buf, + input_data, + input_node, +): + """Convert a tflite model buffer in a Relay module""" + + def convert_to_list(x): + if not isinstance(x, list): + x = [x] + return x + + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + except ImportError: + raise ImportError("The tflite package must be installed") + + input_data = convert_to_list(input_data) + input_node = convert_to_list(input_node) + + shape_dict = {} + dtype_dict = {} + for i, e in enumerate(input_node): + shape_dict[e] = input_data[i].shape + dtype_dict[e] = input_data[i].dtype.name + + mod, params = relay.frontend.from_tflite( + tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict + ) + mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) + return mod, params + + def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): """ This method runs a process and logs the output to both a log file and stdout @@ -221,6 +261,7 @@ def compile_and_run( params=None, workspace_byte_alignment=8, mod_name=None, + enable_op_fusion=True, ): """ This method verifies the generated source @@ -232,7 +273,11 @@ def compile_and_run( if not use_calculated_workspaces: cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK " - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + config = {"tir.disable_vectorize": True} + if not enable_op_fusion: + config["relay.FuseOps.max_depth"] = 1 + + with tvm.transform.PassContext(opt_level=3, config=config): lib = tvm.relay.build(mod, target, target_host=target, params=params, 
mod_name=mod_name)

     tmp_path = utils.tempdir()
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 485267cb03f7..5505c4eb630b 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -466,5 +466,53 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5),
 )


+def test_quant_mobilenet_tfl():
+    """In AOT the output buffer is passed in directly by the user, so output buffers cannot be
+    shared in quantized networks: the output data type is int8 while the intermediate buffers are
+    int32 or int16. We use quantized MobileNet to stress this situation and verify that output
+    buffer sharing is disabled in AOT."""
+    pytest.importorskip("tflite")
+
+    import tvm.relay.testing.tf as tf_testing
+
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/"
+        "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+        "mobilenet_v1_1.0_224_quant.tflite",
+    )
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data_shape = (1, 224, 224, 3)
+    in_min, in_max = (0, 255)
+    data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8")
+    mod, params = convert_to_relay(tflite_model_buf, data, "input")
+    inputs = {"input": data}
+    output_list = generate_ref_data(mod, inputs, params)
+    input_list = [inputs["input"]]
+    compile_and_run(mod, input_list, output_list, "--unpacked-api=0", True, params)
+
+
+@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"])
+def test_transpose(target_options):
+    """Test that operations which cannot be computed in place (e.g., transpose) are not
+    executed in place."""
+
+    dtype = "float32"
+    x = relay.var("x", shape=(10, 5), dtype=dtype)
+    y = relay.var("y", shape=(10, 5), dtype=dtype)
+    t = relay.var("z", shape=(), dtype=dtype)
+    a = relay.add(x, y)
+    b = relay.transpose(a)
+    z = relay.add(b, t)
+    # Check result.
+    func = relay.Function([x, y, t], z)
+    x_data = np.random.rand(10, 5).astype(dtype)
+    y_data = np.random.rand(10, 5).astype(dtype)
+    t_data = np.random.uniform(size=()).astype(dtype)
+    inputs = {"x": x_data, "y": y_data, "z": t_data}
+
+    output_list = generate_ref_data(func, inputs)
+    input_list = [inputs["x"], inputs["y"], inputs["z"]]
+    compile_and_run(func, input_list, output_list, target_options, True, enable_op_fusion=False)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
new file mode 100644
index 000000000000..626af0c96633
--- /dev/null
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import tvm +from tvm.script import ty +from tvm import te, tir +import numpy as np +import tvm.testing +import pytest + + +@tvm.script.tir +class Module: + def tir_packed_call() -> None: + A = tir.var("handle") + B = tir.var("handle") + C = tir.var("handle") + # body + tir.evaluate( + tir.tvm_call_cpacked( + "tvm_test_cpacked", + A, + B, + C, + dtype="int32", + ) + ) + + +@tvm.script.tir +class Expected: + def tir_packed_call() -> None: + A = tir.var("handle") + B = tir.var("handle") + C = tir.var("handle") + + # body + tvm_value_2 = tir.var("handle") + tvm_value_1 = tir.var("handle") + tvm_value_0 = tir.var("handle") + with tir.let(tvm_value_2, tir.tvm_stack_alloca("array", 1, dtype="handle")): + with tir.let(tvm_value_1, tir.tvm_stack_alloca("array", 1, dtype="handle")): + with tir.let(tvm_value_0, tir.tvm_stack_alloca("array", 1, dtype="handle")): + tir.evaluate(tir.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle")) + tir.evaluate(tir.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle")) + tir.evaluate(tir.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle")) + tir.evaluate( + tir.tvm_call_cpacked( + "tvm_test_cpacked", + tvm_value_0, + tvm_value_1, + tvm_value_2, + dtype="int32", + ) + ) + + +def test_aot_packed_call(): + mod = Module() + expected = Expected() + out = tir.transform.LegalizePackedCalls()(mod) + tvm.ir.assert_structural_equal(expected, out, map_free_vars=True) + + +if __name__ == "__main__": + pytest.main([__file__])