diff --git a/CMakeLists.txt b/CMakeLists.txt index 8995f9a87fb76..7c355238b8c8d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -305,6 +305,7 @@ tvm_file_glob(GLOB_RECURSE RELAY_PASS_SRCS tvm_file_glob(GLOB RELAY_BACKEND_SRCS src/relay/backend/*.cc src/relay/backend/vm/*.cc + src/relay/backend/aot/*.cc ) tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS src/relay/ir/*.cc diff --git a/python/tvm/relay/backend/_aot.py b/python/tvm/relay/backend/_aot.py new file mode 100644 index 0000000000000..437cd71c4c359 --- /dev/null +++ b/python/tvm/relay/backend/_aot.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The AOT FFI namespace. +""" +import tvm._ffi + +tvm._ffi._init_api("relay.backend.aot", __name__) diff --git a/python/tvm/relay/backend/aot.py b/python/tvm/relay/backend/aot.py new file mode 100644 index 0000000000000..8e7406c72f32b --- /dev/null +++ b/python/tvm/relay/backend/aot.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""AOT passes""" +from tvm.ir.transform import Pass +from .utils import CallType + +from . import _aot + + +def AOTLowerMain(mod_name: str, config: object, call_type: CallType) -> Pass: + """Lower a Relay main function into an AOT TIR main function. + + Parameters + ---------- + mod_name: str + The name of the module. + config : CompilationConfig + The compilation configuration. + call_type : CallType + The calling convention to use. + + Returns + ------- + Pass + The AOTLowerMain pass. + + """ + return _aot.AOTLowerMain(mod_name, config, call_type.value) diff --git a/python/tvm/relay/backend/utils.py b/python/tvm/relay/backend/utils.py index b8430a9e6b6eb..7289dbbc4af4b 100644 --- a/python/tvm/relay/backend/utils.py +++ b/python/tvm/relay/backend/utils.py @@ -15,6 +15,13 @@ # specific language governing permissions and limitations # under the License. 
"""Utility backend functions.""" +from enum import Enum + + +class CallType(Enum): + Packed = 0 + CPacked = 1 + Unpacked = 2 def _is_valid_modname(mod_name): diff --git a/src/relay/backend/aot/aot_lower_main.cc b/src/relay/backend/aot/aot_lower_main.cc new file mode 100644 index 0000000000000..62005984d2003 --- /dev/null +++ b/src/relay/backend/aot/aot_lower_main.cc @@ -0,0 +1,870 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/aot/aot_lower_main.cc + * \brief Lower the Relay main func into an AOT TIR main func. + */ +#include "./aot_lower_main.h" + +#include +#include + +#include "../../op/call/call.h" +#include "../../op/memory/device_copy.h" +#include "../../op/memory/memory.h" +#include "../../transforms/device_aware_visitors.h" +#include "../name_transforms.h" +#include "../utils.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +using StorageMap = + std::unordered_map; + +/*! + * \brief Looks at the expressions in a given function and produces an Expr to + * StorageInfo map by assigning one or more StorageInfos to the expressions that + * require storage. + * + * This pass is leveraged by AOTMainLowerer to perform an initial naive allocation + * for tensors in the Relay main function. The resulting storage map is then lowered + * into TIR allocations by AOTMainLowerer where the allocation can be subsequently + * optimized by later passes (e.g. USMP). + */ +class ExprAllocator : public transform::DeviceAwareExprVisitor { + public: + ExprAllocator() : transform::DeviceAwareExprVisitor(Optional()) {} + + // run the visitor on a global function. 
+ void Run(const Function& func) { VisitExpr(func); } + + std::vector GetReturnSIDs() const { return return_sids_; } + + StorageMap GetStorageMap() const { return expr_storage_map_; } + + using ExprVisitor::VisitExpr_; + + void DeviceAwareVisitExpr_(const CallNode* call_node) final { + Expr func; + Array args; + + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + if (call_lowered_props.lowered_func.defined()) { + func = call_lowered_props.lowered_func; + args = call_lowered_props.arguments; + } else { // Relay functions that have not been lowered and lowered extern functions + func = call_node->op; + args = call_node->args; + if (call_node->op.as()) { // Lowered extern function + ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes."; + } else { // Relay function which has not been lowered yet + ICHECK(call_node->op.as()) + << "Expected the call to be to a lowered primfunc, a lowered extern function or a " + "unlowered Relay function."; + } + } + VisitExpr(func); + CreateStorage(call_node); + for (const Expr& arg : args) { + VisitExpr(arg); + } + AssignReturnSID(GetRef(call_node)); + } + + void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { + if (function_nesting() > 1) { + // Do not recurse into sub functions. + return; + } + if (func_node->HasNonzeroAttr(attr::kPrimitive)) { + // No storage needed for primitive functions + return; + } + for (const auto& param : func_node->params) { + CreateStorage(param.get()); + } + VisitExpr(func_node->body); + } + + void PreVisitLetBinding_(const Var& var, const Expr& value) final { + VisitExpr(value); + StorageInfo si = GetStorage(value); + expr_storage_map_[var] = si; + } + + void VisitExpr_(const ConstantNode* op) final { + CreateStorage(op); + AssignReturnSID(GetRef(op)); + } + + void VisitExpr_(const VarNode* op) final { AssignReturnSID(GetRef(op)); } + + void VisitExpr_(const TupleNode* op) final { + std::vector storage_ids; + std::vector virtual_devices; + std::vector storage_sizes_in_bytes; + Expr expr = GetRef(op); + for (Expr field : op->fields) { + auto sid = GetStorage(field); + storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end()); + virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(), + sid->virtual_devices.end()); + storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(), + sid->storage_sizes_in_bytes.begin(), + sid->storage_sizes_in_bytes.end()); + } + expr_storage_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes); + AssignReturnSID(expr); + } + + void VisitExpr_(const TupleGetItemNode* op) final { + Expr expr = GetRef(op); + auto sids = GetStorage(op->tuple); + ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); + expr_storage_map_[expr] = + StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]}, + {sids->storage_sizes_in_bytes[op->index]}); + AssignReturnSID(expr); + } + + void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "'If' is not supported."; } + + private: + /*! + * \brief Assign the expression's storage IDs as the return storage IDs. + * \note This is called when visiting every expression on the understanding + * that the returned expression will be visited last. + */ + void AssignReturnSID(const Expr& e) { + if (expr_storage_map_.find(e) != expr_storage_map_.end()) { + StorageInfo& sinfo = expr_storage_map_[e]; + return_sids_.clear(); + for (auto sid : sinfo->storage_ids) { + return_sids_.push_back(sid); + } + } + } + + /*! 
+ * \brief Get the necessary storage for the expression. + * \param expr The expression. + * \return The corresponding token. + */ + StorageInfo GetStorage(const Expr& expr) { + // See through "on_device" calls. + Expr true_expr = IgnoreOnDevice(expr); + VisitExpr(true_expr); + auto it = expr_storage_map_.find(true_expr); + ICHECK(it != expr_storage_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " " + << PrettyPrint(true_expr) << " in storage device map"; + return it->second; + } + + /*! + * \brief Create storage for the expression. + */ + void CreateStorage(const ExprNode* op) { + Expr expr = GetRef(op); + return CreateStorage(expr, GetVirtualDevice(expr)); + } + + /*! + * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device. + */ + void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) { + ICHECK(!virtual_device->IsFullyUnconstrained()) + << "invalid virtual device for expr:" << std::endl + << PrettyPrint(expr); + std::vector storage_ids; + std::vector virtual_devices; + std::vector storage_sizes_in_bytes; + for (const auto& ttype : FlattenTupleType(expr->checked_type())) { + storage_ids.push_back(next_available_sid_++); + virtual_devices.push_back(virtual_device); + storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype->shape, ttype->dtype)); + } + expr_storage_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices), + std::move(storage_sizes_in_bytes)); + } + + /*! \brief Map between Exprs and StorageInfos */ + StorageMap expr_storage_map_; + /*! \brief The next available storage ID to be used */ + int next_available_sid_{0}; + /*! \brief The storage IDs that correspond to return values */ + std::vector return_sids_; +}; + +class AOTMainLowerer : public MixedModeVisitor { + public: + AOTMainLowerer(tvm::CompilationConfig config, CallType call_type) + : config_(config), call_type_(call_type) {} + + IRModule Lower(IRModule mod, String mod_name) { + VLOG_CONTEXT << "AOT"; + IRModule lowered_mod = GetRef(mod.CopyOnWrite()); + + auto lowered_main = lowered_mod->Lookup("main"); + auto lowered_main_func = GetRef(lowered_main.as()); + + // Assign StorageInfo to all the Relay exprs + ExprAllocator expr_allocator; + expr_allocator.Run(lowered_main_func); + expr_storage_map_ = expr_allocator.GetStorageMap(); + + for (auto input : lowered_main_func->params) { + input_vars_.push_back(input); + std::string input_name = SanitizeName(input->name_hint()); + // We don't want the compiler changing input names in the + // event of a sanitization collision. Therefore, enforcing + // the var created to use the input_name strictly. + CreateIOVar(input, input_name, /*use_unique_name = */ false); + } + + // Define the storage allocator ids + for (auto kv : expr_storage_map_) { + for (auto sid : kv.second->storage_ids) { + // The buffer_var is created with storage_scope to be global.workspace to be serviced by + // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor + // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and + // should not be lowered to the stack. 
For more details please refer to the discussion here: + // https://github.com/apache/tvm/issues/9022 + tir::Var buffer_var(MakeString("sid_", sid), + PointerType(PrimType(DataType::Int(8)), "global.workspace")); + sids_table_[sid] = buffer_var; + } + } + + // Retrieve the return sids + return_sid_ = expr_allocator.GetReturnSIDs(); + // Create output vars for the TIR main func + // If output tensor names were provided use them + if (auto opt = lowered_main->GetAttr>("output_tensor_names")) { + Array output_tensor_names = opt.value(); + Expr output_expr = lowered_main_func->body; + if (output_expr->checked_type()->IsInstance()) { + TupleType output_tuple_type = Downcast(output_expr->checked_type()); + for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) { + // AoT Executor Codegen does not create these names, + // thus should be used as they are provided. + CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i], + /*use_unique_name = */ false); + } + } else { + // AoT Executor Codegen does not create these names, + // thus should be used as they are provided. + CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false); + } + } else { + // If output tensor names are not provided we will generate output(x) + // where x is a counter to create unique names. + if (lowered_main_func->body->checked_type()->IsInstance()) { + CreateIOVar(lowered_main_func->body, "output"); + } else { + CreateIOVar(lowered_main_func->body, "output", /*use_unique_name = */ false); + } + } + + CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts") + .value_or(Map())); + VisitExpr(lowered_main_func->body); + + // Remove the Relay main and replace it with the lowered TIR version + lowered_mod->Remove(lowered_mod->GetGlobalVar("main")); + auto tir_main_func = CreateMainFunc(mod_name); + lowered_mod->Update(GlobalVar(runtime::symbol::tvm_module_main), tir_main_func); + lowered_mod = tir::transform::RemoveNoOp()(lowered_mod); + return lowered_mod; + } + + void VisitExpr_(const CallNode* call_node) override { + OnDeviceProps on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + VisitExpr(on_device_props.body); + return; + } + + DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + + if (device_copy_props.body.defined()) { + // TODO(mbs): device_copy cleaunp + // Suspect treating as no-op is better since already built into the StorageInfo? + LOG(FATAL) << "The AOT executor does not currently support device_copy"; + return; + } + + // At this point we should only see calls of the form call_lowered(@callee, (args...)), + // where @callee can be a PrimFunc we've compiled or an external function supplied via + // some other mechanism. + ICHECK(call_lowered_props.lowered_func.defined()) + << "AOT does not support calling Relay functions. Attempting to call:" << std::endl + << PrettyPrint(GetRef(call_node)); + for (const auto& arg : call_lowered_props.arguments) { + // Evaluate the args + VisitExpr(arg); + } + CreateFuncCall(call_lowered_props, GetRef(call_node)); + } + + void VisitExpr_(const VarNode* op) override { + Expr expr = GetRef(op); + StorageInfo& sinfo = expr_storage_map_[expr]; + + // Let bound vars refer to a value, so these should not be considered "output" vars. 
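+    // They share their StorageInfo with the bound value, so any copy to the output buffer is handled when that value is visited.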
+ if (let_bound_vars_.find(GetRef(op)) != let_bound_vars_.end()) { + return; + } + + // If the Var node is an output node we need to copy the content of the variable to the output + // It's safe to check the SID here because Var StorageToken are never reallocated + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + auto var_expr = FindExpr(expr); + CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0], + /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); + } + } + + void VisitExpr_(const ConstantNode* op) override { + Expr expr = GetRef(op); + ICHECK(expr_storage_map_.find(expr) != expr_storage_map_.end()) + << "Storage map did not contain constant expr " << PrettyPrint(expr); + StorageInfo& sinfo = expr_storage_map_[expr]; + std::stringstream ss; + ss << "constant_" << constant_map_.size(); + + tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype)))); + constant_map_[constant] = op; + auto sid = sinfo->storage_ids[0]; + sids_table_[sid] = constant; + + // If the Constant node is an output node we need to copy the content of the parameter to the + // output. A node can only produce a single output + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(ss.str())}); + CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), constant, + /* pack_input */ false, sinfo->storage_sizes_in_bytes[0]); + } + } + + void VisitExpr_(const TupleNode* op) override { + for (auto field : op->fields) { + VisitExpr(field); + } + } + + void VisitExpr_(const LetNode* op) override { + auto pre_visit = [this](const LetNode* op) { + let_bound_vars_.insert(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(op, pre_visit, post_visit); + } + + void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } + void VisitExpr_(const OpNode* op) override { + if (GetRef(op) != CallLoweredOp() && GetRef(op) != OnDeviceOp()) { + LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded"; + } + } + void VisitExpr_(const IfNode* op) override { + LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called"; + } + void VisitExpr_(const FunctionNode* op) override { + ICHECK(op->GetAttr(attr::kCompiler).defined()) + << "FunctionNode only supported by custom codegen"; + } + void VisitExpr_(const RefCreateNode* op) override { + LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)"; + } + void VisitExpr_(const RefReadNode* op) override { + LOG(FATAL) << "AOT executor does not support references (found RefReadNode)"; + } + void VisitExpr_(const RefWriteNode* op) override { + LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)"; + } + void VisitExpr_(const ConstructorNode* op) override { + LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)"; + } + void VisitExpr_(const MatchNode* op) override { + LOG(FATAL) << "AOT executor does not support matching (found MatchNode)"; + } + + private: + /*! 
+ * \brief Create the main PrimFunc to execute the graph. + * \note The packed function calls don't pack their arguments. The AOT + * runner function needs to be legalized by the LegalizePackedCalls pass. + */ + tir::PrimFunc CreateMainFunc(String mod_name) { + tir::Stmt body = tir::SeqStmt(stmts_); + // Allocate the sids + std::unordered_map allocated; + std::vector> sids_to_allocate; + + for (auto kv : expr_storage_map_) { + // Only allocate sids that are needed + const bool is_input = + (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end()); + if (is_input) { + continue; + } + + for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) { + sids_to_allocate.push_back( + std::make_pair(kv.second->storage_ids[i], kv.second->storage_sizes_in_bytes[i])); + } + } + + // Sort the SID allocation to make output deterministic + std::sort(sids_to_allocate.begin(), sids_to_allocate.end()); + + for (auto p : sids_to_allocate) { + int sid = p.first; + int size = p.second; + + if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) { + continue; + } + + // Make sure it hasn't already been allocated, this can happen + // with let-bound var/value pairs. + if (allocated.find(sid) != allocated.end()) { + continue; + } + + allocated[sid] = constant_map_.count(sids_table_[sid]); + + // TODO(giuseros): we should allocate this once outside the PrimFunc + // so we don't pay the price of allocation for every inference + if (!allocated[sid]) { + PointerType ptype = Downcast(sids_table_[sid]->type_annotation); + DataType element_type = Downcast(ptype->element_type)->dtype; + body = tir::Allocate(sids_table_[sid], element_type, {size}, tir::const_true(), body); + } + allocated[sid] = true; + } + + for (auto kv : constant_map_) { + auto buffer_var = kv.first; + auto dtype = DataType(kv.second->data->dtype); + + int ndim = kv.second->data->ndim; + Array extents; + + for (int i = 0; i < ndim; i++) { + int shape = kv.second->data->shape[i]; + extents.push_back(tir::make_const(DataType::Int(32), shape, Span())); + } + body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body); + } + + // Define the PrimFunc attributes + Map dict_attrs; + String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main); + dict_attrs.Set("global_symbol", run_func_name); + dict_attrs.Set("runner_function", Bool(true)); + dict_attrs.Set(tvm::attr::kTarget, config_->host_target); + Array input_vars = + Array(main_signature_.begin(), main_signature_.begin() + input_vars_.size()); + dict_attrs.Set("input_vars", input_vars); + Array output_vars = + Array(main_signature_.begin() + input_vars_.size(), + main_signature_.begin() + input_vars_.size() + return_sid_.size()); + dict_attrs.Set("output_vars", output_vars); + + tir::Stmt device_activations = GenerateAllDeviceHook("Activate"); + tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate"); + tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); + + // Make the PrimFunc + return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {}, + DictAttrs(dict_attrs)); + } + + /*! 
+ * \brief Collects device context variables for passing to operators + */ + void CollectDeviceVariables(const Map& device_contexts) { + Map target_contexts; + TargetKindAttrMap target_attr_map = tvm::TargetKind::GetAttrMap("use_device_api"); + + for (const auto& it : device_contexts) { + const GlobalVar& global_var = it.first; + const std::string device_context_name = it.second; + + Optional target_kind = tvm::TargetKind::Get(device_context_name); + if (!target_kind || !target_attr_map.count(target_kind.value())) { + return; + } + if (target_attr_map[target_kind.value()]) { + std::string context_name = SanitizeName(device_context_name); + tir::Var device_context_var("device_context_" + context_name, DataType::Handle()); + + auto pair = target_contexts.find(target_kind.value()); + if (pair != target_contexts.end()) { + device_context_var = (*pair).second; + } else { + main_signature_.push_back(device_context_var); + devices_.Set(context_name, device_context_var); + target_contexts.Set(target_kind.value(), device_context_var); + } + + device_contexts_.Set(global_var, device_context_var); + } + } + } + + /*! + * \brief Return a vector of variables that represents the sids for the given Relay Expr + */ + std::vector PackSid(Expr expr) { + std::vector buffer_vars; + + ICHECK(expr_storage_map_.find(expr) != expr_storage_map_.end()) + << "Storage map did not contain constant expr " << PrettyPrint(expr); + StorageInfo& sinfo = expr_storage_map_[expr]; + + // Note that an expression can have multiple sids associated with it + // e.g., returning multiple values from a function + for (auto sid : sinfo->storage_ids) { + // Determine if an sid is an output buffer + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + buffer_vars.push_back(GetBufferVarForIO(input_vars_.size() + output_index)); + continue; + } + + auto sid_value = sids_table_[sid]; + buffer_vars.push_back(sid_value); + } + return buffer_vars; + } + + /*! + * \brief Given an expression return the variable(s) associated with that expression + */ + std::vector FindExpr(Expr arg) { + auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg); + if (input_iter != input_vars_.end()) { + // Input variable + int main_index = std::distance(input_vars_.begin(), input_iter); + return {GetBufferVarForIO(main_index)}; + } else { + // Storage identifier (i.e., intermediate memory) + return PackSid(arg); + } + } + + void PushArgs(const Expr& expr, const std::vector& sids, Array* args) { + const TupleNode* t = expr.as(); + if (t != nullptr) { + CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't " + "handle this type of Relay Expr in a CallNode."; + } + + args->insert(args->end(), sids.begin(), sids.end()); + } + + /*! + * \brief Wraps a call_extern with a tvm_check_return annotation if required otherwise + * returns the passed Call + */ + tir::Call AddCheckReturn(tir::Call existing_call) { + Array args = {tir::make_const(DataType::Int(32, 1), 0, Span()), + tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call}; + return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); + } + + /*! + * \brief Create a function call + * \param call_lowered_props The lowered function and the arguments to call it with + * \param result_expr The call we got func and args from (so as to recover the storage + * ids to hold the result). 
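+   * \note The emitted call is appended to stmts_; when the callee uses the C Device API it is wrapped in the corresponding Open/Close device hooks.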
+ */ + void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) { + std::string func_name = call_lowered_props.lowered_func->name_hint; + tvm::Array args{tvm::tir::StringImm(func_name)}; + std::vector create_func_call_stmts; + + // Pack the inputs + for (const Expr& arg : call_lowered_props.arguments) { + auto sids = FindExpr(arg); + PushArgs(arg, sids, &args); + } + + // Pack the return(s) value. A call node can produce multiple outputs + auto result_expr_sid = PackSid(result_expr); + PushArgs(result_expr, result_expr_sid, &args); + + GlobalVar global_var = call_lowered_props.lowered_func; + bool has_c_device_api_context = device_contexts_.count(global_var) != 0; + tir::Var device_context; + tir::Stmt func_call; + + switch (call_type_) { + case CallType::kUnpacked: { + // call_extern calling convention with optional context + if (has_c_device_api_context) { + device_context = device_contexts_.Get(global_var).value(); + args.push_back(device_context); + } + func_call = tir::Evaluate(AddCheckReturn( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); + break; + } + case CallType::kCPacked: { + if (has_c_device_api_context) { + device_context = device_contexts_.Get(global_var).value(); + args.push_back(device_context); + } else { + // NOTE: LowerTVMBuiltin expects some device_context placeholder. + args.push_back(tir::make_zero(DataType::Handle())); + } + func_call = tir::Evaluate( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)); + create_func_call_stmts.push_back(func_call); + break; + } + case CallType::kPacked: { + // call_packed does not accept a device context. + CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context"; + func_call = tir::Evaluate(AddCheckReturn( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); + create_func_call_stmts.push_back(func_call); + break; + } + default: + ICHECK(false) << "Unknown CallType: " << call_type_; + } + + ICHECK(func_call.defined()) << "Must define func_call"; + + if (has_c_device_api_context) { + func_call = tir::SeqStmt(Array({ + GenerateDeviceHook(device_context, "Open"), + func_call, + GenerateDeviceHook(device_context, "Close"), + })); + } + + tir::Stmt body = tir::SeqStmt({func_call}); + stmts_.push_back(body); + } + + /*! + * \brief Copy a variable to the output. This function is mainly used in edge cases + * when we want to return an input or a parameter. + * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a + * copy-on-write fashion. + */ + void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) { + // Define intermediate DLTensor to load/store the data + tir::Buffer tmp_read = + tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read"); + tir::Buffer tmp_write = + tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write"); + te::Var loop_idx("i", DataType::Int(32)); + auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); + // Copy the variable from the input to the output + tir::Stmt copy = tir::For( + loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial, + tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); + stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); + } + + /*! 
+ * \brief Generates a call to a given hook for all Devices found for C Device API + * \param hook Name of hook to generate statements for + * \return Statement with function calls for each device + */ + tir::Stmt GenerateAllDeviceHook(const String& hook) { + std::vector device_hooks; + for (const auto& it : devices_) { + const String& device_name = it.first; + const tir::Var& context = it.second; + Array sections = {"Device", device_name, hook}; + String device_hook_name = ToCFunctionStyle(PrefixName(sections)); + + tir::Evaluate device_hook( + AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), + {tvm::tir::StringImm(device_hook_name), context}))); + device_hooks.push_back(device_hook); + } + return tir::SeqStmt(device_hooks); + } + + /*! + * \brief Generates a call to a given hook for a single Device function + * \param context Device context to call hook on + * \param hook Name of hook to generate statements for + * \return Statement with function call to Device API + */ + tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) { + const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) { + return it.second->name_hint == context->name_hint; + }); + const String& device_name = (*it).first; + Array sections = {"Device", device_name, hook}; + String device_hook = ToCFunctionStyle(PrefixName(sections)); + + return tir::Evaluate( + AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), + {tvm::tir::StringImm(device_hook), context}))); + } + + /*! + * \brief Utility function to string together different arguments + */ + template + std::string MakeString(Args const&... args) { + std::ostringstream ss; + using List = int[]; + (void)List{0, ((void)(ss << args), 0)...}; + + return ss.str(); + } + + /*! + * \brief Access IO vars using the buffer vars and + * not the actual var. + */ + tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; } + + /*! + * \brief Create tir::Var for input/output while updating the buffer_maps. + * \param expr The expression to evaluate. + * \param original_name The name of the tir::Var. + * \param use_unique_name Whether to generate a new unique name where a name conflicts. + */ + void CreateIOVar(const Expr& expr, const std::string& original_name, + bool use_unique_name = true) { + CreateIOVar(expr->checked_type(), original_name, use_unique_name); + } + + /*! + * \brief Create tir::Var for input/output while updating the buffer_maps. + * \param expr The expression to evaluate. + * \param original_name The name of the tir::Var. + * \param use_unique_name Whether to generate a new unique name where a name conflicts. 
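+   * \note Tuple types are handled by recursively creating one I/O var per field.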
+ */ + void CreateIOVar(const Type& type, const std::string& original_name, + bool use_unique_name = true) { + if (type->IsInstance()) { + TupleType tuple_type = Downcast(type); + for (unsigned i = 0; i < tuple_type->fields.size(); i++) { + CreateIOVar(tuple_type->fields[i], original_name); + } + } else { + std::string name = original_name; + if (use_unique_name) { + name = GetUniqueIOVarName(original_name); + } + tir::Var var = tir::Var(name, DataType::Handle()); + main_signature_.push_back(var); + auto tensor_type = type.as(); + ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey(); + DataType elem_type = tensor_type->dtype; + tir::Var buffer_var = + tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global")); + tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, + name + "_buffer", 16, 1, tir::BufferType::kDefault); + main_buffer_map_.Set(var, buffer); + } + } + + /*! + * \brief Create a unique name for I/O Var + */ + std::string GetUniqueIOVarName(std::string name) { + if (io_var_names_.find(name) == io_var_names_.end()) { + io_var_names_[name] = 1; + return name + std::to_string(io_var_names_[name] - 1); + } else { + io_var_names_[name] = io_var_names_[name] + 1; + return name + std::to_string(io_var_names_[name] - 1); + } + } + + /*! \brief list of input expressions (i.e., variable passed by the user) */ + std::vector input_vars_; + /*! \brief map of device contexts variables */ + Map devices_; + /*! \brief map of GlobalVars to C Device API contexts */ + Map device_contexts_; + /*! \brief input and output variables belonging to the main function signature */ + Array main_signature_; + /*! \brief input and output variables belonging to the main function signature */ + Map main_buffer_map_; + /*! \brief All available targets. */ + CompilationConfig config_; + /*! + * \brief The type of kernel call to be emitted. + * See CallType for more documentation. + */ + CallType call_type_; + std::unordered_map + constant_map_; + /*! \brief plan memory of device result */ + StorageMap expr_storage_map_; + /*! \brief mapping sid -> tir::Var */ + std::unordered_map sids_table_; + /*! \brief the set of statements that make the program */ + std::vector stmts_; + /*! \brief the list of return sids (note that the function might return more then one output */ + std::vector return_sid_; + /*! \brief This is per IO var name counter to aid the generating unique names */ + std::unordered_map io_var_names_; + /*! \brief A set of variables that are let bound. 
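+   * Used in VisitExpr_(const VarNode*) to avoid emitting an output copy for let-bound vars.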
*/ + std::unordered_set let_bound_vars_; +}; + +Pass AOTLowerMain(String mod_name, tvm::CompilationConfig config, CallType call_type) { + runtime::TypedPackedFunc pass_func = + [=](IRModule module, transform::PassContext ctx) { + return AOTMainLowerer(config, call_type).Lower(module, mod_name); + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "AOTLowerMain", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay.backend.aot.AOTLowerMain") + .set_body_typed([](const String& mod_name, const tvm::CompilationConfig& config, + int call_type) { + return AOTLowerMain(mod_name, config, static_cast(call_type)); + }); + +} // namespace aot +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/aot/aot_lower_main.h b/src/relay/backend/aot/aot_lower_main.h new file mode 100644 index 0000000000000..2127ad00cbff0 --- /dev/null +++ b/src/relay/backend/aot/aot_lower_main.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAY_BACKEND_AOT_AOT_LOWER_MAIN_H_ +#define TVM_RELAY_BACKEND_AOT_AOT_LOWER_MAIN_H_ + +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +/*! \brief Lower the Relay main function into TIR for use with the AOT executor. + * + * This pass expects that all operators have already been lowered to TIR and + * so only Calls to 'call_lowered' are present in main. + * + * \param mod_name The name of the module. + * \param config The compilation config. + * \param call_type The call type to use when calling functions. 
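+ * \return The AOTLowerMain module pass; it requires InferType to have been run.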
+ */ +transform::Pass AOTLowerMain(String mod_name, tvm::CompilationConfig config, CallType call_type); + +} // namespace aot +} // namespace backend +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_AOT_AOT_LOWER_MAIN_H_ diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 5cf7a5563d195..51bcab527d1b7 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -138,8 +138,20 @@ TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan") return StaticMemoryPlan(expr_to_storage_info); }); -// TODO(mbs): Cf GetMemorySizeBytes in aot_executor_codegen.cc, GetMemorySize in -// graph_plan_memory.cc +size_t DivRoundUp(size_t size, size_t word_size) { return (size + word_size - 1) / word_size; } + +size_t GetMemorySizeBytes(const Array& shape, const DataType& dtype) { + size_t size = 1; + for (IndexExpr dim : shape) { + const int64_t* pval = tir::as_const_int(dim); + ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << shape; + ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; + size *= static_cast(pval[0]); + } + size *= DivRoundUp(dtype.bits() * dtype.lanes(), 8); + return size; +} + int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { if (expr_type->IsInstance()) { auto tuple_type = Downcast(expr_type); @@ -152,17 +164,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { auto tensor_type = expr_type.as(); ICHECK(tensor_type); auto shape = tensor_type->shape; - int num_of_elements = 1; - for (const auto& dim_index_expr : shape) { - if (dim_index_expr->IsInstance()) { - num_of_elements *= dim_index_expr.as()->value; - } else { - // If shape is dynamic, we cannot calculate workspace in compile time. - num_of_elements = 0; - } - } - auto element_size = tensor_type->dtype.bytes(); - return element_size * num_of_elements; + return GetMemorySizeBytes(tensor_type->shape, tensor_type->dtype); } TVM_REGISTER_NODE_TYPE(FunctionInfoNode); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 37ae9d803a352..6c65a081f156d 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -59,6 +59,73 @@ class TECompiler; namespace backend { using Pass = tvm::transform::Pass; +/*! \brief Describes the type of kernel call emitted. */ +enum CallType { + /*! + * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions. + * + * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the + * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those + * functions are of type TVMBackendPackedCFunc. + * + * The following code is emitted at call sites to call a function named `func`: + * void* func_ptr = TVMBackendGetFuncFromEnv("func"); + * TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes) + * + * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` + * by LowerTVMBuiltin TIR transform. + * + * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often, + * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when + * `func` is implemented in C). + * + * Compatible with both C++ and C runtimes, implemented with the C runtime only. + */ + kPacked, // Emit tir.call_packed and wrap all arguments in DLTensor. + + /*! + * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call. 
+ * + * When this type is selected, assumes all operators are implemented in functions of type + * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of + * downstream compilation that there is a symbol named after the 0th arg to tir::Call of + * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target. + * + * The following code is emitted at call sites to call a function named `func`: + * func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle) + * + * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args` + * by LowerTVMBuiltin TIR transform. + * + * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is + * always the device context parameter when not null. At present, the implementation does not + * support forwarding device context parameters to CPacked. + * + * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented + * in the same scenarios. + */ + kCPacked, // Emit tir.call_cpacked and wrap all arguments in DLTensor. + + /*! \brief Directly call a function accepting the `data` arrays as args. + * + * When this type is selected, assumes all operaotrs are implemented in C functions whose + * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the + * `data` parameters (i.e. no DLTensor object is passed along). + * + * The following code is emitted at call sites to a function named `func`: + * func(void* arg0, void* arg1, ..., void* argN) // no resource_handle + * -or- + * func(void* arg0, void* arg1, ..., void* argN, void* resource_handle) // with resource_handle + * + * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is + * always the device context parameter when not null. + * + * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented + * with the C runtime only. + */ + kUnpacked, // Emit tir.call_extern passing only the `data` part of DLTensors. +}; + /*! * \brief Structure that can be optionally used by the executor codegen */ @@ -207,6 +274,13 @@ class FunctionInfo : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FunctionInfo, ObjectRef, FunctionInfoNode); }; +/*! + * \brief Calculate the bytes of memory needed to hold a tensor of a given shape and data type. + * \param shape The shape of the tensor + * \param dtype The data type of the tensor + */ +size_t GetMemorySizeBytes(const Array& shape, const DataType& dtype); + /*! * \brief Calculate the storage required to store the type of relay.Expr * diff --git a/tests/python/relay/aot/test_pass_aot_lower_main.py b/tests/python/relay/aot/test_pass_aot_lower_main.py new file mode 100644 index 0000000000000..c583b287727a6 --- /dev/null +++ b/tests/python/relay/aot/test_pass_aot_lower_main.py @@ -0,0 +1,429 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long,missing-class-docstring,missing-module-docstring,missing-function-docstring,no-self-argument,unused-argument,invalid-name +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm.script import tir as T +from tvm.relay.backend.aot import AOTLowerMain, CallType + + +def _make_const(dtype, shape): + return tvm.relay.const(np.zeros(shape).astype(dtype)) + + +def _make_consts(dtype, shapes): + return [_make_const(dtype, shape) for shape in shapes] + + +def _plan_devices(mod): + host_target = tvm.target.Target("llvm") + prim_target = tvm.target.Target("llvm", host=host_target) + ctxt = tvm.transform.PassContext() + config = tvm.target.make_compilation_config(ctxt, prim_target) + mod = tvm.relay.transform.PlanDevices(config)(mod) + mod = tvm.relay.transform.InferType()(mod) + return mod, config + + +def _assert_lowered_main(mod, main_func, call_type, print_script=False): + mod, config = _plan_devices(mod) + mod = AOTLowerMain("test_mod", config, call_type)(mod) + if print_script: + print(mod["__tvm_main__"].script()) + + assert mod["__tvm_main__"].script() == main_func.script() + + +def test_single_call_cpacked(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32]) { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a,) /* ty=(Tensor[(5, 7), float32],) */; + call_lowered(@test_fused_add, %0) /* ty=Tensor[(5, 7), float32] */ +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + T.evaluate(T.tvm_call_cpacked("test_fused_add", a_buffer.data, output_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +def test_single_call_packed(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32]) { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a,) /* ty=(Tensor[(5, 7), float32],) */; + call_lowered(@test_fused_add, %0) /* ty=Tensor[(5, 7), float32] */ +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + T.evaluate(T.tvm_check_return(0, -1, T.tvm_call_packed("test_fused_add", a_buffer.data, output_buffer.data, dtype="int32"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, 
CallType.Packed) + + +def test_single_call_unpacked(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32]) { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a,) /* ty=(Tensor[(5, 7), float32],) */; + call_lowered(@test_fused_add, %0) /* ty=Tensor[(5, 7), float32] */ +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + T.evaluate(T.tvm_check_return(0, -1, T.call_extern("test_fused_add", a_buffer.data, output_buffer.data, dtype="int32"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.Unpacked) + + +def test_constant(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a, meta[relay.Constant][0]) /* ty=(Tensor[(5, 7), float32], Tensor[(5, 7), float32]) */; + call_lowered(@test_fused_add, %0) /* ty=Tensor[(5, 7), float32] */ +} + """, + init_meta_table={"relay.Constant": _make_consts("float32", [(5, 7)])}, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "global_symbol": "test_mod___tvm_main__", "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + constant_0 = T.allocate_const([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "float32", [5, 7]) + T.evaluate(T.tvm_call_cpacked("test_fused_add", a_buffer.data, constant_0, output_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +# TODO(@mbaret) There seems to be a TVMScript round-trip bug causing this to fail +@pytest.mark.xfail() +def test_copy_to_output(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %a +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + tmp_read = T.buffer_var("uint8", "") + # buffer definition + tmp_read_1 = T.buffer_decl([T.uint64(140)], dtype="uint8", data=tmp_read) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + tmp_write: T.Ptr[T.uint8] = output_buffer.data + tmp_write_1 = T.buffer_decl([T.uint64(140)], dtype="uint8", data=tmp_write) + for i in T.serial(140): + tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i]) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +def test_two_calls(): + mod = tvm.parser.parse( + """ +#[version = 
"0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32]) { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a,) /* ty=(Tensor[(5, 7), float32],) */; + %1 = call_lowered(@test_fused_add, %0) /* ty=Tensor[(5, 7), float32] */; + %2 = (%1,) /* ty=(Tensor[(5, 7), float32],) */; + call_lowered(@test_fused_add, %2) /* ty=Tensor[(5, 7), float32] */ +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + sid_2 = T.allocate([140], "int8", "global.workspace") + T.evaluate(T.tvm_call_cpacked("test_fused_add", a_buffer.data, sid_2, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + T.evaluate(T.tvm_call_cpacked("test_fused_add", sid_2, output_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +def test_tuple_output(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32]) { (%x, %x) } + +def @main(%a: Tensor[(5, 7), float32]) -> (Tensor[(5, 7), float32], Tensor[(5, 7), float32]) { + %0 = (%a,) /* ty=(Tensor[(5, 7), float32],) */; + call_lowered(@test_fused_add, %0) /* ty=(Tensor[(5, 7), float32], Tensor[(5, 7), float32]) */ +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output0: T.handle, output1: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output0, output1]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output0_buffer = T.match_buffer(output0, [5, 7], dtype="float32", align=16) + output1_buffer = T.match_buffer(output1, [5, 7], dtype="float32", align=16) + # body + T.evaluate(T.tvm_call_cpacked("test_fused_add", a_buffer.data, output0_buffer.data, output1_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +def test_tuple_intermediate(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add_0(%x: Tensor[(5, 7), float32]) -> (Tensor[(5, 7), float32], Tensor[(5, 7), float32]) { (%x, %x) } +def @test_fused_add_1(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a,); + %1 = call_lowered(@test_fused_add_0, %0); + %2 = (%1.0, %1.1); + call_lowered(@test_fused_add_1, %2) +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + sid_3 = T.allocate([140], "int8", "global.workspace") + sid_2 = T.allocate([140], "int8", "global.workspace") + 
T.evaluate(T.tvm_call_cpacked("test_fused_add_0", a_buffer.data, sid_2, sid_3, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + T.evaluate(T.tvm_call_cpacked("test_fused_add_1", sid_2, sid_3, output_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +def test_multi_input(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %x } + +def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a, %b) /* ty=(Tensor[(5, 7), float32], Tensor[(5, 7), float32]) */; + call_lowered(@test_fused_add, %0) /* ty=Tensor[(5, 7), float32] */ +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, b: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a, b], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + b_buffer = T.match_buffer(b, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + T.evaluate(T.tvm_call_cpacked("test_fused_add", a_buffer.data, b_buffer.data, output_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +def test_let_binding(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a,); + let %v1 = call_lowered(@test_fused_add, %0); + %v1 +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + T.evaluate(T.tvm_call_cpacked("test_fused_add", a_buffer.data, output_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +def test_let_binding_branch(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add_0(%x: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { %x } +def @test_fused_add_1(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a,); + let %v0 = call_lowered(@test_fused_add_0, %0); + %1 = (%v0,); + let %v1 = call_lowered(@test_fused_add_0, %1); + %2 = (%v1,); + let %v2 = call_lowered(@test_fused_add_0, %2); + %3 = (%v1, %v2); + let %v3 = call_lowered(@test_fused_add_1, %3); + %v3 +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 
7], dtype="float32", align=16) + # body + sid_3 = T.allocate([140], "int8", "global.workspace") + sid_2 = T.allocate([140], "int8", "global.workspace") + sid_1 = T.allocate([140], "int8", "global.workspace") + T.evaluate(T.tvm_call_cpacked("test_fused_add_0", a_buffer.data, sid_1, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + T.evaluate(T.tvm_call_cpacked("test_fused_add_0", sid_1, sid_2, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + T.evaluate(T.tvm_call_cpacked("test_fused_add_0", sid_2, sid_3, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + T.evaluate(T.tvm_call_cpacked("test_fused_add_1", sid_2, sid_3, output_buffer.data, T.reinterpret(T.uint64(0), dtype="handle"), dtype="int32")) + # fmt: on + + _assert_lowered_main(mod, func, CallType.CPacked) + + +def test_device_hooks(): + mod = tvm.parser.parse( + """ +#[version = "0.0.5"] +def @test_fused_add(%x: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { %x } + +def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = (%a,); + %1 = call_lowered(@test_fused_add, %0); + %2 = (%1,); + call_lowered(@test_fused_add, %2) +} + """, + ) + + # fmt: off + @T.prim_func + def func(a: T.handle, output: T.handle, device_context_example_target_hook: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]}) + a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) + # body + T.evaluate(T.tvm_check_return(0, -1, T.call_extern("TVMDeviceExampleTargetHookActivate", device_context_example_target_hook, dtype="int32"), dtype="int32")) + with T.allocate([140], "int8", "global.workspace") as sid_2: + T.evaluate(T.tvm_check_return(0, -1, T.call_extern("TVMDeviceExampleTargetHookOpen", device_context_example_target_hook, dtype="int32"), dtype="int32")) + T.evaluate(T.tvm_call_cpacked("test_fused_add", a_buffer.data, sid_2, device_context_example_target_hook, dtype="int32")) + T.evaluate(T.tvm_check_return(0, -1, T.call_extern("TVMDeviceExampleTargetHookClose", device_context_example_target_hook, dtype="int32"), dtype="int32")) + T.evaluate(T.tvm_check_return(0, -1, T.call_extern("TVMDeviceExampleTargetHookOpen", device_context_example_target_hook, dtype="int32"), dtype="int32")) + T.evaluate(T.tvm_call_cpacked("test_fused_add", sid_2, output_buffer.data, device_context_example_target_hook, dtype="int32")) + T.evaluate(T.tvm_check_return(0, -1, T.call_extern("TVMDeviceExampleTargetHookClose", device_context_example_target_hook, dtype="int32"), dtype="int32")) + T.evaluate(T.tvm_check_return(0, -1, T.call_extern("TVMDeviceExampleTargetHookDeactivate", device_context_example_target_hook, dtype="int32"), dtype="int32")) + # fmt: on + + device_contexts = {} + for gv in mod.get_global_vars(): + device_contexts[gv] = "example_target_hook" + + mod = mod.with_attr("device_contexts", device_contexts) + + _assert_lowered_main(mod, func, CallType.CPacked) + + +if __name__ == "__main__": + tvm.testing.main()
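A minimal sketch of driving the new pass from Python, mirroring the _plan_devices helper in the test above; the module name "my_mod" and the llvm targets are illustrative assumptions, not part of the patch:

import tvm
from tvm.relay.backend.aot import AOTLowerMain, CallType

def lower_main_sketch(mod):
    # Plan devices and re-infer types first, as the pass expects a device-planned, type-inferred module.
    host_target = tvm.target.Target("llvm")
    prim_target = tvm.target.Target("llvm", host=host_target)
    config = tvm.target.make_compilation_config(tvm.transform.PassContext(), prim_target)
    mod = tvm.relay.transform.PlanDevices(config)(mod)
    mod = tvm.relay.transform.InferType()(mod)
    # All operators must already be lowered; main may only contain call_lowered calls.
    return AOTLowerMain("my_mod", config, CallType.CPacked)(mod)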