diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index b592265c74cd..1fef02557e09 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -556,6 +556,19 @@ TVM_DLL Pass PlanDevices(CompilationConfig config); */ TVM_DLL Pass FlattenAtrousConv(); +/*! + * \brief Annotates the minimum required memory of each primitive function callsite by analyzing + * the liveness of the input/output tensors at each function callsite and calculating the total + * amount of memory these tensors require. This is added as a "used_memory" annotation to the + * function in question as a list of the number of bytes for each callsite. In addition, the + * containing function is annotated with an "io_used_memory" annotation which refers to the total + * memory required for the IO tensors. + * + * Note: This pass does not support dynamic shapes; it is the user's responsibility to check this + * pass isn't applied where dynamic shapes may be input. + */ +TVM_DLL Pass AnnotateUsedMemory(); + } // namespace transform /*! diff --git a/src/relay/backend/annotate_used_memory.cc b/src/relay/backend/annotate_used_memory.cc new file mode 100644 index 000000000000..ad370c73ad1e --- /dev/null +++ b/src/relay/backend/annotate_used_memory.cc @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/annotate_used_memory.cc + * \brief Analyzes the used memory at the callsite of primitive functions. + */ + +#include +#include +#include + +#include +#include + +#include "../transforms/device_aware_visitors.h" +#include "../transforms/pass_utils.h" +#include "./liveness_analysis.h" +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace backend { + +/*! + * \brief Annotates the minimum required memory of each primitive function callsite by analyzing + * the liveness of the input/output tensors at each function callsite and calculating the total + * amount of memory these tensors require. This is added as a "used_memory" annotation to the + * function in question as a list of the number of bytes for each callsite. In addition, the + * containing function is annotated with an "io_used_memory" annotation which refers to the total + * memory required for the IO tensors. + * + * Note: This pass does not support dynamic shapes; it is the user's responsibility to check this + * pass isn't applied where dynamic shapes may be input.
+ * + * A simple example: + * + * Before: + * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] { + * let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] { + * nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) + * }; + * let %x_1 = %x_0(%input); + * %x_1 + * } + * + * After: + * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] { + * let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=[32]) -> Tensor[(1, 2, + * 2, 4), int8] { + * nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) + * }; + * let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input); + * %x_1 + * } + * + * Note that in the simple example above io_used_memory and used_memory are the same since there + * is only one primitive function. + */ +class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { + public: + AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, + const transform::LivenessAnalysis& lva) + : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} + + /*! + * \brief Mutates the input function. In addition, an "io_used_memory" annotation is + * added to the input function which refers to the total size required for the IO + * tensors. 
+ */ + Function operator()(const Function& func) { + uint64_t io_used_memory = 0; + + // Inputs + for (const Var& param : func->params) { + Type type = param->checked_type(); + ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes."; + io_used_memory += CalculateRelayExprSizeBytes(type); + } + + // Outputs + Type type = func->body->checked_type(); + ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes."; + io_used_memory += CalculateRelayExprSizeBytes(type); + + Expr new_func_body = VisitExpr(func->body); + Function new_func = WithFields(func, func->params, new_func_body); + return WithAttr(std::move(new_func), "io_used_memory", + tvm::IntImm(tvm::DataType::UInt(64), io_used_memory)); + } + + /*! + * \brief Establish which let bindings have primitive function values. + */ + std::pair PreVisitLetBinding_(const Var& var, const Expr& value) { + if (const auto* func_node = value.as()) { + ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive)) + << "Expect top-level functions to be primitive."; + let_bound_prim_func_.insert(var); + } + return DeviceAwareExprMutator::PreVisitLetBinding_(var, value); + } + + /*! + * \brief Visit let nodes and perform one of two actions depending on their value: + * + * 1. CallNode - Calculate "used_memory" annotation value at the callsite of + * primitive functions. + * + * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the + * previous analysis at the callsite. 
+ * + */ + Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override { + Var let_var = post_let_node->var; + Expr let_value = IgnoreOnDevice(post_let_node->value); + + if (let_value->IsInstance()) { + Call callsite = Downcast(let_value); + if (CheckPrimitiveFunctionCall(callsite)) { + Var call_op = Downcast(callsite->op); + + // Find all the vars that are live at the callsite. This is done by merging the + // in and out var sets and then removing the var that references the primitive + // function itself since we don't want this included in the calculation. + const transform::ControlFlowGraph::NodePtr cfg_node = + control_flow_graph_.let_map.at(GetRef(pre_let_node)); + transform::VarSet live_tensors = liveness_.live_in.at(cfg_node); + const transform::VarSet& live_out = liveness_.live_out.at(cfg_node); + live_tensors.insert(live_out.begin(), live_out.end()); + live_tensors.erase(call_op); + + // Calculate size of live tensors and store to allow annotation when the function + // gets visited.
+ uint64_t used_memory = 0; + for (const auto& var : live_tensors) { + Type type = var->checked_type(); + ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes."; + used_memory += CalculateRelayExprSizeBytes(type); + } + IntImm annotation(DataType::UInt(64), used_memory); + used_memory_annotations_[call_op].push_back(annotation); + } + } else if (let_value->IsInstance()) { + Function func = Downcast(let_value); + ICHECK(used_memory_annotations_.find(let_var) != used_memory_annotations_.end()) + << "Could not find used_memory value for primitive function bound at " + << let_var->name_hint(); + Array used_memory = used_memory_annotations_[let_var]; + used_memory_annotations_.erase(let_var); + + Function new_func = WithAttr(std::move(func), "used_memory", + Array(used_memory.rbegin(), used_memory.rend())); + return Let(let_var, new_func, post_let_node->body, post_let_node->span); + } + + return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); + } + + private: + /*! + * \brief Check if a call is a primitive function callsite. + */ + bool CheckPrimitiveFunctionCall(const Call& callsite) { + if (const auto* var_node = callsite->op.as()) { + Var var = GetRef(var_node); + if (let_bound_prim_func_.find(var) != let_bound_prim_func_.end()) { + return true; + } + } + return false; + } + + /*! \brief Control flow graph representation of the main function. */ + transform::ControlFlowGraph control_flow_graph_; + /*! \brief Liveness analysis of the main function. */ + transform::LivenessAnalysis liveness_; + /*! \brief Vars that reference primitive functions. */ + std::unordered_set let_bound_prim_func_; + /*! \brief Stores the calculated uint64 used_memory values so they can be annotated on the + * relevant function.
*/ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> used_memory_annotations_; +}; + +} // namespace backend + +namespace transform { + +Pass AnnotateUsedMemory() { + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext ctx) { + GlobalVar gv = mod->GetGlobalVar("main"); + Function main_func = Downcast(mod->Lookup("main")); + + // Perform liveness analysis to determine what tensors are 'live' at each function's callsite. + support::Arena arena; + ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, main_func); + UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); + LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); + + auto new_main_func = backend::AnnotateUsedMemoryMutator(mod, cfg, lva)(main_func); + if (!new_main_func.same_as(main_func)) { + mod->Update(gv, new_main_func); + } + return mod; + }; + return CreateModulePass(pass_func, 0, "AnnotateUsedMemory", {"ToANormalForm", "InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.AnnotateUsedMemory").set_body_typed(AnnotateUsedMemory); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 5938417128e0..5020e79714b2 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1079,6 +1079,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { } mod = transform::ToANormalForm()(mod); + mod = transform::InferType()(mod); + mod = transform::AnnotateUsedMemory()(mod); IRModule lowered_mod = tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) { diff --git a/src/relay/backend/liveness_analysis.cc b/src/relay/backend/liveness_analysis.cc new file mode 100644 index 000000000000..52db9e6a4c23 --- /dev/null +++ b/src/relay/backend/liveness_analysis.cc @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/liveness_analysis.cc + * \brief Analysis that collects the live variables before and after each node. + * NOTE: the input IR should be in ANF. + */ + +#include "./liveness_analysis.h" + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +using support::Arena; +using VarSet = std::unordered_set; + +ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) { + return Creator().Create(arena, body); +} + +ControlFlowGraph ControlFlowGraph::Creator::Create(Arena* arena, const Expr& body) { + arena_ = arena; + cfg_.entry = BasicBlock::Make(arena); + VisitExpr(body, cfg_.entry); + return std::move(cfg_); +} + +void ControlFlowGraph::Creator::Succ(BasicBlockPtr from, BasicBlockPtr to) { + from->succ.push_back(to); + to->pred.push_back(from); +} + +void ControlFlowGraph::Creator::VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) { + ICHECK(!in_func_) << "nested functions not supported by CFG analysis"; + in_func_ = true; + + // Unwrap the nested function and proceed normally. 
+ if (f->HasNonzeroAttr(attr::kClosure)) { + ICHECK(f->body.as()); + return VisitExpr(Downcast(f->body)->body, parent); + } + + return VisitExpr(f->body, parent); +} + +void ControlFlowGraph::Creator::VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) { + Expr expr = GetRef(let_node); + + while (const LetNode* inner_let_node = expr.as()) { + NodePtr curr_node = Node::Make(arena_, parent, expr); + + ICHECK(!cfg_.let_map.count(expr)); + cfg_.let_map[expr] = curr_node; + cfg_.reverse_post_order.push_back(curr_node); + + // The basic block ends upon reaching control flow, with successor blocks corresponding to the + // control flow branch exprs (true/false in If, and one for each clause in Match). + if (const IfNode* ite = AsIgnoringOnDevice(inner_let_node->value)) { + // Create the basic blocks for each branch and mark them as successors to the current block. + BasicBlockPtr t_block = BasicBlock::Make(arena_); + BasicBlockPtr f_block = BasicBlock::Make(arena_); + Succ(parent, t_block); + Succ(parent, f_block); + + VisitExpr(ite->true_branch, t_block); + VisitExpr(ite->false_branch, f_block); + + // All subsequent bindings (and/or the body expr) will be in a new basic block. + BasicBlockPtr next = BasicBlock::Make(arena_); + Succ(t_block, next); + Succ(f_block, next); + parent = next; + } else if (const MatchNode* match = AsIgnoringOnDevice(inner_let_node->value)) { + // Same as above but one for each pattern. + std::vector clause_blocks; + BasicBlockPtr next = BasicBlock::Make(arena_); + for (const Clause& clause : match->clauses) { + BasicBlockPtr clause_block = BasicBlock::Make(arena_); + Succ(parent, clause_block); + Succ(clause_block, next); + VisitExpr(clause->rhs, clause_block); + } + parent = next; + } + + expr = inner_let_node->body; + } + + VisitExpr(expr, parent); +} + +void ControlFlowGraph::Creator::VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) { + // TODO(@altanh): is there a way of making this work? 
+ LOG(FATAL) << "If expressions should be bound to variables."; +} + +void ControlFlowGraph::Creator::VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) { + // TODO(@altanh): same as If + LOG(FATAL) << "Match expressions should be bound to variables."; +} + +VarSet VarUseCollector::VisitExpr_(const VarNode* var_node) { return {GetRef(var_node)}; } + +VarSet VarUseCollector::VisitExpr_(const CallNode* call_node) { + VarSet use = VisitExpr(call_node->op); + for (const Expr& arg : call_node->args) { + VarSet arg_use = VisitExpr(arg); + use.insert(arg_use.begin(), arg_use.end()); + } + return use; +} + +VarSet VarUseCollector::VisitExpr_(const TupleNode* tuple_node) { + VarSet use; + for (const Expr& field : tuple_node->fields) { + VarSet field_use = VisitExpr(field); + use.insert(field_use.begin(), field_use.end()); + } + return use; +} + +VarSet VarUseCollector::VisitExpr_(const TupleGetItemNode* get_node) { + return VisitExpr(get_node->tuple); +} + +VarSet VarUseCollector::VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); } + +VarSet VarUseCollector::VisitExpr_(const MatchNode* match_node) { + return VisitExpr(match_node->data); +} + +UseDefAnalysis UseDefAnalysis::Analyze(const CFG& cfg) { + UseDefAnalysis a; + + // One pass is sufficient. 
+ for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) { + const CFG::NodePtr& node = *it; + if (const LetNode* let_node = AsIgnoringOnDevice(node->expr)) { + a.use[node] = a.use_collector.VisitExpr(let_node->value); + a.def[node] = let_node->var; + } else { + a.use[node] = a.use_collector.VisitExpr(node->expr); + a.def[node] = Var(); + } + } + + return a; +} + +bool SetEqual(const VarSet& a, const VarSet& b) { + if (a.size() != b.size()) { + return false; + } + for (auto& xa : a) { + if (!b.count(xa)) { + return false; + } + } + return true; +} + +LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg, + const UseDefAnalysis& use_def) { + LivenessAnalysis a; + std::list worklist; + + // Initialize worklist to post-order traversal for quick convergence. + worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend()); + + // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm. + auto visitor = [&](const CFG::NodePtr n) { + VarSet old_in_n = a.live_in[n]; + VarSet old_out_n = a.live_out[n]; + + a.live_in[n] = use_def.use.at(n); + for (const Var& v : a.live_out[n]) { + if (!v.same_as(use_def.def.at(n))) { + a.live_in[n].insert(v); + } + } + + a.live_out[n] = VarSet(); + for (const CFG::NodePtr& s : n->GetSucc()) { + a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end()); + } + + if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) { + // No need to update the worklist. + } else { + // Add predecessor nodes back to worklist (no need to add successors, since each node's + // in/out sets are not dependent on its predecessors). 
+ for (const CFG::NodePtr& p : n->GetPred()) { + worklist.push_back(p); + } + } + }; + + while (!worklist.empty()) { + const CFG::NodePtr n = worklist.front(); + worklist.pop_front(); + visitor(n); + } + + return a; +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/liveness_analysis.h b/src/relay/backend/liveness_analysis.h new file mode 100644 index 000000000000..4e9514056b86 --- /dev/null +++ b/src/relay/backend/liveness_analysis.h @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/liveness_analysis.h + * \brief Analysis that collects the live variables before and after each node. + * NOTE: the input IR should be in ANF. 
+ */ + +#ifndef TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_ +#define TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_ + +#include + +#include +#include +#include + +#include "../../support/arena.h" +#include "../op/memory/device_copy.h" +#include "../transforms/device_aware_visitors.h" +#include "../transforms/let_list.h" + +namespace tvm { +namespace relay { +namespace transform { + +using support::Arena; +using VarSet = std::unordered_set; + +// TODO(@altanh, @mbs, @mbrookhart): we should do a survey of all "*-flow graphs" in the codebase +// to see what can be deduplicated. + +// TODO(@altanh): support Relay Refs once/if they are supported by the VM. + +/*! + * \brief A representation of an input expression (typically a Function) as a directed graph of + * basic blocks, with edges between basic blocks corresponding to control flow branching. + */ +class ControlFlowGraph { + public: + struct Node; + struct BasicBlock; + + using NodePtr = Node*; + using BasicBlockPtr = BasicBlock*; + + /*! + * \brief A chunk of IR that does not have any control flow branching. At this stage in the IR, + * basic blocks correspond to: + * (1) a sequence of nested Let expressions, where each node in the block corresponds to a + * binding and the last node is either the (non-Let) body or a binding that branches + * (e.g. "let %x = if (%c) { true_block } else { false_block }"). + * (2) an atomic expression representing the target expression of a control flow branch, e.g. + * %v and %u in "let %x = if (%c) { %v } else { %u }". + */ + struct BasicBlock { + // The nodes of the basic block. + std::vector nodes; + // The predecessor basic blocks. + std::vector pred; + // The successor basic blocks. + std::vector succ; + + static BasicBlockPtr Make(support::Arena* arena) { return arena->make(); } + }; + + /*! + * \brief Roughly corresponds to a "statement" in the IR, such as an individual binding in a + * basic block or the "return value" of a block. 
Each node maps to a single corresponding expr in + * the IR, but the converse is not true (e.g. in the case of variables). + */ + struct Node { + /*! \brief The basic block this node belongs to. */ + BasicBlockPtr parent; + /*! \brief The index into the parent basic block where this node is. */ + size_t index; + /*! \brief The expr this node corresponds to. */ + Expr expr; + + /*! \brief Returns whether or not this node is the first one in the parent basic block. */ + bool IsFirst() const { return index == 0; } + + /*! \brief Returns whether or not this node is the last one in the parent basic block. */ + bool IsLast() const { return index == parent->nodes.size() - 1; } + + /*! \brief Returns the predecessor nodes of this node. */ + std::vector GetPred() const { + std::vector pred; + if (IsFirst()) { + for (const BasicBlockPtr& pred_block : parent->pred) { + pred.push_back(pred_block->nodes.back()); + } + } else { + pred.push_back(parent->nodes[index - 1]); + } + return pred; + } + + /*! \brief Returns the successor nodes of this node. */ + std::vector GetSucc() const { + std::vector succ; + if (IsLast()) { + for (const BasicBlockPtr& succ_block : parent->succ) { + succ.push_back(succ_block->nodes.front()); + } + } else { + succ.push_back(parent->nodes[index + 1]); + } + return succ; + } + + /*! \brief Creates a node with the given expr and appends it to the parent basic block. */ + static NodePtr Make(Arena* arena, BasicBlockPtr parent, Expr expr) { + NodePtr n = arena->make(); + n->parent = parent; + n->expr = expr; + n->index = parent->nodes.size(); + parent->nodes.push_back(n); + return n; + } + }; + + /*! \brief The basic block where control flow begins. */ + BasicBlockPtr entry; + + /*! + * \brief Mapping from Let expressions to their corresponding nodes. Note that Let expressions + * are never shared in ANF (unlike vars), so this is an injection. + */ + std::unordered_map let_map; + + /*! \brief The nodes of the CFG in reverse post order. 
*/ + std::vector reverse_post_order; + + /*! \brief Creates and returns the CFG of the given expression. */ + static ControlFlowGraph Create(Arena* arena, const Expr& body); + + private: + class Creator; +}; + +/*! \brief Helper class for building CFGs. */ +class ControlFlowGraph::Creator : private ExprFunctor { + public: + Creator() {} + + ControlFlowGraph Create(Arena* arena, const Expr& body); + + private: + /*! \brief The arena allocator. */ + Arena* arena_; + + /*! \brief The CFG being built. */ + ControlFlowGraph cfg_; + /*! + * \brief Whether or not we are in a function. CFGs do not support nested functions so this is + * used to error out in such a case. + */ + bool in_func_ = false; + + /*! + * \brief Link \p to as a successor block to \p from. + */ + void Succ(BasicBlockPtr from, BasicBlockPtr to); + +#define DEFAULT_CFG(OP) \ + void VisitExpr_(const OP* op, BasicBlockPtr parent) final { \ + NodePtr n = Node::Make(arena_, parent, GetRef(op)); \ + cfg_.reverse_post_order.push_back(n); \ + } + + void VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) final; + void VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) final; + void VisitExpr_(const IfNode* if_node, BasicBlockPtr parent); + void VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent); + + DEFAULT_CFG(VarNode); + DEFAULT_CFG(GlobalVarNode); + DEFAULT_CFG(ConstantNode); + DEFAULT_CFG(CallNode); + DEFAULT_CFG(OpNode); + DEFAULT_CFG(TupleNode); + DEFAULT_CFG(TupleGetItemNode); +}; + +/*! + * \brief Helper class for collecting the variables used/read by an expression. NOTE: for If exprs, + * only the condition is included (not the branches). Similarly, for Match exprs only the value + * being deconstructed is included. 
+ */ +class VarUseCollector : public ExprFunctor { + public: + VarSet VisitExpr_(const VarNode* var_node); + VarSet VisitExpr_(const CallNode* call_node); + VarSet VisitExpr_(const TupleNode* tuple_node); + VarSet VisitExpr_(const TupleGetItemNode* get_node); + VarSet VisitExpr_(const IfNode* if_node); + VarSet VisitExpr_(const MatchNode* match_node); + + VarSet VisitExpr_(const ConstructorNode* cons_node) { return {}; } + VarSet VisitExpr_(const GlobalVarNode* gvar_node) { return {}; } + VarSet VisitExpr_(const ConstantNode* const_node) { return {}; } + VarSet VisitExpr_(const OpNode* op_node) { return {}; } + VarSet VisitExpr_(const FunctionNode* func_node) { return {}; } +}; + +/*! + * \brief Analysis that collects the variables used and defined at each node. + */ +struct UseDefAnalysis { + using CFG = ControlFlowGraph; + + /*! \brief Mapping of node -> variables used/read by node. */ + std::unordered_map use; + + /*! \brief Mapping of node -> variable defined/written by node. */ + std::unordered_map def; + + VarUseCollector use_collector; + + static UseDefAnalysis Analyze(const CFG& cfg); +}; + +/*! \brief Returns whether \p a and \p b are the same set of vars. */ +bool SetEqual(const VarSet& a, const VarSet& b); + +/*! + * \brief Analysis that collects the live variables before and after each node. + */ +struct LivenessAnalysis { + using CFG = ControlFlowGraph; + + /*! \brief Mapping of node -> set of variables live before node. */ + std::unordered_map live_in; + + /*! \brief Mapping of node -> set of variables live after node. */ + std::unordered_map live_out; + + /*! + * \brief Analyze the input \p cfg (using info from \p use_def). + * + * \param cfg The input control flow graph. + * \param use_def Use-def analysis of \p cfg. 
+ * \return LivenessAnalysis + */ + static LivenessAnalysis Analyze(const ControlFlowGraph& cfg, const UseDefAnalysis& use_def); +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_ diff --git a/src/relay/backend/vm/manifest_lifetimes.cc b/src/relay/backend/vm/manifest_lifetimes.cc index 3ba129702b52..486e06320345 100644 --- a/src/relay/backend/vm/manifest_lifetimes.cc +++ b/src/relay/backend/vm/manifest_lifetimes.cc @@ -29,398 +29,12 @@ #include "../../op/memory/device_copy.h" #include "../../transforms/device_aware_visitors.h" #include "../../transforms/let_list.h" +#include "../liveness_analysis.h" namespace tvm { namespace relay { namespace transform { -using support::Arena; -using VarSet = std::unordered_set; - -// TODO(@altanh, @mbs, @mbrookhart): we should do a survey of all "*-flow graphs" in the codebase -// to see what can be deduplicated. - -// TODO(@altanh): support Relay Refs once/if they are supported by the VM. - -/*! - * \brief A representation of an input expression (typically a Function) as a directed graph of - * basic blocks, with edges between basic blocks corresponding to control flow branching. - */ -class ControlFlowGraph { - public: - struct Node; - struct BasicBlock; - - using NodePtr = Node*; - using BasicBlockPtr = BasicBlock*; - - /*! - * \brief A chunk of IR that does not have any control flow branching. At this stage in the IR, - * basic blocks correspond to: - * (1) a sequence of nested Let expressions, where each node in the block corresponds to a - * binding and the last node is either the (non-Let) body or a binding that branches - * (e.g. "let %x = if (%c) { true_block } else { false_block }"). - * (2) an atomic expression representing the target expression of a control flow branch, e.g. - * %v and %u in "let %x = if (%c) { %v } else { %u }". - */ - struct BasicBlock { - // The nodes of the basic block. - std::vector nodes; - // The predecessor basic blocks. 
- std::vector pred; - // The successor basic blocks. - std::vector succ; - - static BasicBlockPtr Make(Arena* arena) { return arena->make(); } - }; - - /*! - * \brief Roughly corresponds to a "statement" in the IR, such as an individual binding in a - * basic block or the "return value" of a block. Each node maps to a single corresponding expr in - * the IR, but the converse is not true (e.g. in the case of variables). - */ - struct Node { - /*! \brief The basic block this node belongs to. */ - BasicBlockPtr parent; - /*! \brief The index into the parent basic block where this node is. */ - size_t index; - /*! \brief The expr this node corresponds to. */ - Expr expr; - - /*! \brief Returns whether or not this node is the first one in the parent basic block. */ - bool IsFirst() const { return index == 0; } - - /*! \brief Returns whether or not this node is the last one in the parent basic block. */ - bool IsLast() const { return index == parent->nodes.size() - 1; } - - /*! \brief Returns the predecessor nodes of this node. */ - std::vector GetPred() const { - std::vector pred; - if (IsFirst()) { - for (const BasicBlockPtr& pred_block : parent->pred) { - pred.push_back(pred_block->nodes.back()); - } - } else { - pred.push_back(parent->nodes[index - 1]); - } - return pred; - } - - /*! \brief Returns the successor nodes of this node. */ - std::vector GetSucc() const { - std::vector succ; - if (IsLast()) { - for (const BasicBlockPtr& succ_block : parent->succ) { - succ.push_back(succ_block->nodes.front()); - } - } else { - succ.push_back(parent->nodes[index + 1]); - } - return succ; - } - - /*! \brief Creates a node with the given expr and appends it to the parent basic block. */ - static NodePtr Make(Arena* arena, BasicBlockPtr parent, Expr expr) { - NodePtr n = arena->make(); - n->parent = parent; - n->expr = expr; - n->index = parent->nodes.size(); - parent->nodes.push_back(n); - return n; - } - }; - - /*! \brief The basic block where control flow begins. 
*/ - BasicBlockPtr entry; - - /*! - * \brief Mapping from Let expressions to their corresponding nodes. Note that Let expressions - * are never shared in ANF (unlike vars), so this is an injection. - */ - std::unordered_map let_map; - - /*! \brief The nodes of the CFG in reverse post order. */ - std::vector reverse_post_order; - - /*! \brief Creates and returns the CFG of the given expression. */ - static ControlFlowGraph Create(Arena* arena, const Expr& body); - - private: - class Creator; -}; - -/*! \brief Helper class for building CFGs. */ -class ControlFlowGraph::Creator : private ExprFunctor { - public: - Creator() {} - - ControlFlowGraph Create(Arena* arena, const Expr& body) { - arena_ = arena; - cfg_.entry = BasicBlock::Make(arena); - VisitExpr(body, cfg_.entry); - return std::move(cfg_); - } - - private: - /*! \brief The arena allocator. */ - Arena* arena_; - - /*! \brief The CFG being built. */ - ControlFlowGraph cfg_; - /*! - * \brief Whether or not we are in a function. CFGs do not support nested functions so this is - * used to error out in such a case. - */ - bool in_func_ = false; - - /*! - * \brief Link \p to as a successor block to \p from. - */ - void Succ(BasicBlockPtr from, BasicBlockPtr to) { - from->succ.push_back(to); - to->pred.push_back(from); - } - -#define DEFAULT_CFG(OP) \ - void VisitExpr_(const OP* op, BasicBlockPtr parent) final { \ - NodePtr n = Node::Make(arena_, parent, GetRef(op)); \ - cfg_.reverse_post_order.push_back(n); \ - } - - void VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) final { - ICHECK(!in_func_) << "nested functions not supported by CFG analysis"; - in_func_ = true; - - // Unwrap the nested function and proceed normally. 
- if (f->HasNonzeroAttr(attr::kClosure)) { - ICHECK(f->body.as()); - return VisitExpr(Downcast(f->body)->body, parent); - } - - return VisitExpr(f->body, parent); - } - - void VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) final { - Expr expr = GetRef(let_node); - - while (const LetNode* inner_let_node = expr.as()) { - NodePtr curr_node = Node::Make(arena_, parent, expr); - - ICHECK(!cfg_.let_map.count(expr)); - cfg_.let_map[expr] = curr_node; - cfg_.reverse_post_order.push_back(curr_node); - - // The basic block ends upon reaching control flow, with successor blocks corresponding to the - // control flow branch exprs (true/false in If, and one for each clause in Match). - if (const IfNode* ite = AsIgnoringOnDevice(inner_let_node->value)) { - // Create the basic blocks for each branch and mark them as successors to the current block. - BasicBlockPtr t_block = BasicBlock::Make(arena_); - BasicBlockPtr f_block = BasicBlock::Make(arena_); - Succ(parent, t_block); - Succ(parent, f_block); - - VisitExpr(ite->true_branch, t_block); - VisitExpr(ite->false_branch, f_block); - - // All subsequent bindings (and/or the body expr) will be in a new basic block. - BasicBlockPtr next = BasicBlock::Make(arena_); - Succ(t_block, next); - Succ(f_block, next); - parent = next; - } else if (const MatchNode* match = AsIgnoringOnDevice(inner_let_node->value)) { - // Same as above but one for each pattern. - std::vector clause_blocks; - BasicBlockPtr next = BasicBlock::Make(arena_); - for (const Clause& clause : match->clauses) { - BasicBlockPtr clause_block = BasicBlock::Make(arena_); - Succ(parent, clause_block); - Succ(clause_block, next); - VisitExpr(clause->rhs, clause_block); - } - parent = next; - } - - expr = inner_let_node->body; - } - - VisitExpr(expr, parent); - } - - void VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) { - // TODO(@altanh): is there a way of making this work? 
- LOG(FATAL) << "If expressions should be bound to variables."; - } - - void VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) { - // TODO(@altanh): same as If - LOG(FATAL) << "Match expressions should be bound to variables."; - } - - DEFAULT_CFG(VarNode); - DEFAULT_CFG(GlobalVarNode); - DEFAULT_CFG(ConstantNode); - DEFAULT_CFG(CallNode); - DEFAULT_CFG(OpNode); - DEFAULT_CFG(TupleNode); - DEFAULT_CFG(TupleGetItemNode); -}; - -ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) { - return Creator().Create(arena, body); -} - -/*! - * \brief Helper class for collecting the variables used/read by an expression. NOTE: for If exprs, - * only the condition is included (not the branches). Similarly, for Match exprs only the value - * being deconstructed is included. - */ -class VarUseCollector : public ExprFunctor { - public: - VarSet VisitExpr_(const VarNode* var_node) { return {GetRef(var_node)}; } - - VarSet VisitExpr_(const CallNode* call_node) { - VarSet use = VisitExpr(call_node->op); - for (const Expr& arg : call_node->args) { - VarSet arg_use = VisitExpr(arg); - use.insert(arg_use.begin(), arg_use.end()); - } - return use; - } - - VarSet VisitExpr_(const TupleNode* tuple_node) { - VarSet use; - for (const Expr& field : tuple_node->fields) { - VarSet field_use = VisitExpr(field); - use.insert(field_use.begin(), field_use.end()); - } - return use; - } - - VarSet VisitExpr_(const TupleGetItemNode* get_node) { return VisitExpr(get_node->tuple); } - - VarSet VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); } - - VarSet VisitExpr_(const MatchNode* match_node) { return VisitExpr(match_node->data); } - - VarSet VisitExpr_(const ConstructorNode* cons_node) { return {}; } - - VarSet VisitExpr_(const GlobalVarNode* gvar_node) { return {}; } - - VarSet VisitExpr_(const ConstantNode* const_node) { return {}; } - - VarSet VisitExpr_(const OpNode* op_node) { return {}; } -}; - -/*! 
- * \brief Analysis that collects the variables used and defined at each node. - */ -struct UseDefAnalysis { - using CFG = ControlFlowGraph; - - /*! \brief Mapping of node -> variables used/read by node. */ - std::unordered_map use; - - /*! \brief Mapping of node -> variable defined/written by node. */ - std::unordered_map def; - - VarUseCollector use_collector; - - static UseDefAnalysis Analyze(const CFG& cfg) { - UseDefAnalysis a; - - // One pass is sufficient. - for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) { - const CFG::NodePtr& node = *it; - if (const LetNode* let_node = AsIgnoringOnDevice(node->expr)) { - a.use[node] = a.use_collector.VisitExpr(let_node->value); - a.def[node] = let_node->var; - } else { - a.use[node] = a.use_collector.VisitExpr(node->expr); - a.def[node] = Var(); - } - } - - return a; - } -}; - -/*! \brief Returns whether \p a and \p b are the same set of vars. */ -bool SetEqual(const VarSet& a, const VarSet& b) { - if (a.size() != b.size()) { - return false; - } - for (auto& xa : a) { - if (!b.count(xa)) { - return false; - } - } - return true; -} - -/*! - * \brief Analysis that collects the live variables before and after each node. - */ -struct LivenessAnalysis { - using CFG = ControlFlowGraph; - - /*! \brief Mapping of node -> set of variables live before node. */ - std::unordered_map live_in; - - /*! \brief Mapping of node -> set of variables live after node. */ - std::unordered_map live_out; - - /*! - * \brief Analyze the input \p cfg (using info from \p use_def). - * - * \param cfg The input control flow graph. - * \param use_def Use-def analysis of \p cfg. - * \return LivenessAnalysis - */ - static LivenessAnalysis Analyze(const ControlFlowGraph& cfg, const UseDefAnalysis& use_def) { - LivenessAnalysis a; - std::list worklist; - - // Initialize worklist to post-order traversal for quick convergence. 
- worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend()); - - // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm. - auto visitor = [&](const CFG::NodePtr n) { - VarSet old_in_n = a.live_in[n]; - VarSet old_out_n = a.live_out[n]; - - a.live_in[n] = use_def.use.at(n); - for (const Var& v : a.live_out[n]) { - if (!v.same_as(use_def.def.at(n))) { - a.live_in[n].insert(v); - } - } - - a.live_out[n] = VarSet(); - for (const CFG::NodePtr& s : n->GetSucc()) { - a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end()); - } - - if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) { - // No need to update the worklist. - } else { - // Add predecessor nodes back to worklist (no need to add successors, since each node's - // in/out sets are not dependent on its predecessors). - for (const CFG::NodePtr& p : n->GetPred()) { - worklist.push_back(p); - } - } - }; - - while (!worklist.empty()) { - const CFG::NodePtr n = worklist.front(); - worklist.pop_front(); - visitor(n); - } - - return a; - } -}; - /*! * \brief Helper class to insert kills using liveness information. */ diff --git a/tests/python/relay/test_used_memory_annotator.py b/tests/python/relay/test_used_memory_annotator.py new file mode 100644 index 000000000000..e339152294b6 --- /dev/null +++ b/tests/python/relay/test_used_memory_annotator.py @@ -0,0 +1,434 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name + +""" +Testing for the pass that annotates used memory for each primitive +Relay function. +""" + +import pytest + +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprVisitor + + +def AnnotateUsedMemory(): + return relay.transform._ffi_api.AnnotateUsedMemory() + + +class CheckUsedMemoryAnnotation(ExprVisitor): + """ + Check that the annotations on each function in the graph match + what is expected. + """ + + def __init__(self, expected_annotations, expected_io_annotation): + self.expected_annotations = expected_annotations + self.expected_io_annotation = expected_io_annotation + super().__init__() + + def visit_function(self, fn): + if "Primitive" in fn.attrs: + assert ( + "used_memory" in fn.attrs + ), "Primitive function does not have used_memory annotation." + + assert len(self.expected_annotations) > 0, "Not all expected annotations were compared" + + expected_mem = self.expected_annotations.pop(0) + actual_mem = [int(x) for x in fn.attrs["used_memory"]] + assert expected_mem == actual_mem, ( + f"Expected used memory annotation {expected_mem} " + f"did not match actual annotation {actual_mem}" + ) + super().visit_function(fn) + + def __call__(self, fn): + assert ( + fn.attrs["io_used_memory"] == self.expected_io_annotation + ), "Expected IO annotation did not match." 
+ self.visit(fn.body) + + +def _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation): + mod = relay.transform.InferType()(mod) + mod = relay.transform.ToANormalForm()(mod) + mod = relay.transform.InferType()(mod) + mod = AnnotateUsedMemory()(mod) + + CheckUsedMemoryAnnotation(expected_annotations, expected_io_annotation)(mod["main"]) + + +def _create_primitive_function(expr): + func = relay.Function(relay.analysis.free_vars(expr), expr) + func = func.with_attr("Primitive", 1) + return func + + +def test_simple(): + """ + Test simple graph with one primitive function. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + call = relay.Call(get_inner_func(), [ifm]) + mod = tvm.IRModule.from_expr(call) + + expected_annotations = [ + [2 * (1 * 2 * 2 * 4)], + ] + expected_io_annotation = 2 * (1 * 2 * 2 * 4) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_multiple_functions(): + """ + Test a graph with multiple primitive functions. 
+ """ + + def get_inner_func(ifm_shape): + x = relay.var("x", shape=ifm_shape, dtype="int8") + x = relay.nn.max_pool2d(x, pool_size=(2, 2), layout="NHWC") + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 8, 8, 2), dtype="int8") + x = get_inner_func((1, 8, 8, 2)) + x = relay.Call(x, [ifm]) + y = get_inner_func((1, 7, 7, 2)) + y = relay.Call(y, [x]) + z = get_inner_func((1, 6, 6, 2)) + z = relay.Call(z, [y]) + mod = tvm.IRModule.from_expr(z) + + expected_annotations = [ + [(1 * 8 * 8 * 2) + (1 * 7 * 7 * 2)], + [(1 * 7 * 7 * 2) + (1 * 6 * 6 * 2)], + [(1 * 6 * 6 * 2) + (1 * 5 * 5 * 2)], + ] + expected_io_annotation = (1 * 8 * 8 * 2) + (1 * 5 * 5 * 2) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_mixed_data_types(): + """ + Test a graph with a primitive function that has mixed datatypes. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 2), dtype="int16") + x = relay.cast(x, dtype="uint32") + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 2, 2, 2), dtype="int16") + x = get_inner_func() + x = relay.Call(x, [ifm]) + mod = tvm.IRModule.from_expr(x) + + expected_annotations = [ + [(1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4], + ] + expected_io_annotation = (1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4 + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_parallel_function_call(): + """ + Test a graph when the results of two functions are concatenated + into a single result. The second function will also have the result + of the first function alive so will be annotated with a larger + "used memory" value. 
+ """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.reshape(x, newshape=(1, 4, 30)) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8") + x = relay.Call(get_inner_func(), [ifm]) + y = relay.Call(get_inner_func(), [ifm]) + z = relay.concatenate([x, y], axis=0) + mod = tvm.IRModule.from_expr(z) + + expected_annotations = [ + [(1 * 4 * 5 * 6) + (1 * 4 * 30)], + # the output tensor from the previous function is also alive + [(1 * 4 * 5 * 6) + (1 * 4 * 30) + (1 * 4 * 30)], + ] + expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 60) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_many_different_parallel_calls(): + """ + Test a graph that calls many different functions in parallel. + + input + / | \ + prim_func_1 prim_func_2 prim_func_3 + \ | / + prim_func_4 + """ + + def get_inner_func_1(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.tanh(x) + x = _create_primitive_function(x) + return x + + def get_inner_func_2(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.nn.max_pool2d(x, pool_size=(1, 1), layout="NHWC") + x = _create_primitive_function(x) + return x + + def get_inner_func_3(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.abs(x) + x = relay.nn.relu(x) + x = relay.exp(x) + x = _create_primitive_function(x) + return x + + def get_inner_func_4(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + y = relay.var("y", shape=(1, 4, 5, 6), dtype="int8") + z = relay.var("z", shape=(1, 4, 5, 6), dtype="int8") + out = relay.concatenate([x, y, z], axis=3) + out = _create_primitive_function(out) + return out + + ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8") + x = relay.Call(get_inner_func_1(), [ifm]) + y = relay.Call(get_inner_func_2(), [ifm]) + z = relay.Call(get_inner_func_3(), [ifm]) + a = relay.Call(get_inner_func_4(), [x, y, z]) + mod = 
tvm.IRModule.from_expr(a) + + expected_annotations = [ + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)], + # output from prim_func_1 is also still alive + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)], + # outputs from prim_func_1 and prim_func_2 are also still alive + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)], + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18)], + ] + expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_nested_branches(): + """ + Tests a graph with branches that also branch. + + input + / \ + / \ + prim_func_1 prim_func_2 + / \ + / \ + prim_func_3 prim_func_4 + """ + + def get_generic_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.relu(x) + return _create_primitive_function(x) + + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + a = relay.Call(get_generic_inner_func(), [ifm]) + b = relay.Call(get_generic_inner_func(), [ifm]) + c = relay.Call(get_generic_inner_func(), [b]) + d = relay.Call(get_generic_inner_func(), [b]) + out = relay.concatenate([a, c, d], axis=3) + mod = tvm.IRModule.from_expr(out) + + expected_annotations = [ + [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)], + # output from prim_func_1 is also still alive + [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)], + # output from prim_func_1 is also still alive + [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)], + # outputs from prim_func_1 and prim_func_3 are also still alive + [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)], + ] + expected_io_annotation = (1 * 2 * 2 * 4) + (1 * 2 * 2 * 12) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_composite_inner_function(): + """ + Tests the typical BYOC use case where a primitive function + contains a composite function. 
+ """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x, pool_size=(2, 2), layout="NHWC") + x = relay.Function(relay.analysis.free_vars(x), x) + x = x.with_attr("Composite", "my_composite_func") + + y = relay.var("y", shape=(1, 2, 2, 4), dtype="int8") + z = relay.Call(x, [y]) + return _create_primitive_function(z) + + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + x = relay.Call(get_inner_func(), [ifm]) + mod = tvm.IRModule.from_expr(x) + + expected_annotations = [ + [(1 * 2 * 2 * 4) + (1 * 1 * 1 * 4)], + ] + expected_io_annotation = (1 * 2 * 2 * 4) + (1 * 1 * 1 * 4) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_multiple_calls_to_same_function(): + """ + Tests the case when there are multiple calls to the same function. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + inner_func = get_inner_func() + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + call1 = relay.Call(inner_func, [ifm]) + call2 = relay.Call(inner_func, [call1]) + mod = tvm.IRModule.from_expr(call2) + + expected_annotations = [[2 * (1 * 2 * 2 * 4), 2 * (1 * 2 * 2 * 4)]] + expected_io_annotation = 2 * (1 * 2 * 2 * 4) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_parallel_calls_to_same_function(): + """ + Test parallel calls to the same function. 
+ """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + inner_func = get_inner_func() + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + call1 = relay.Call(inner_func, [ifm]) + call2 = relay.Call(inner_func, [ifm]) + concat = relay.concatenate([call1, call2], axis=0) + mod = tvm.IRModule.from_expr(concat) + + expected_annotations = [[2 * (1 * 2 * 2 * 4), 3 * (1 * 2 * 2 * 4)]] + expected_io_annotation = 3 * (1 * 2 * 2 * 4) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_parallel_calls_with_non_ifm_input(): + """ + Test a graph that calls many different functions in parallel where + the input is not the input to the function. + + y = f(x) + / | \ + z0 = g0(y) ... zi = gi(y) + \ | / + concat + """ + + def get_inner_func_1(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.tanh(x) + x = _create_primitive_function(x) + return x + + def get_inner_func_2(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.nn.max_pool2d(x, pool_size=(2, 2)) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8") + y = relay.Call(get_inner_func_1(), [ifm]) + g = get_inner_func_2() + + no_calls = 20 + z = [relay.Call(g, [y]) for _ in range(0, no_calls)] + out = relay.concatenate(z, axis=3) + mod = tvm.IRModule.from_expr(out) + + expected_annotations = [ + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)], + [(1 * 4 * 5 * 6) + (1 * 4 * 4 * 5) * i for i in range(1, no_calls + 1)], + ] + expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 4 * (5 * no_calls)) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_dynamic_io_tensor_not_supported(): + """ + Test to check dynamic IO tensor error. 
+ """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 2, 2, relay.Any()), dtype="int8") + call = relay.Call(get_inner_func(), [ifm]) + mod = tvm.IRModule.from_expr(call) + + err_rgx = r"AnnotateUsedMemory does not support dynamic shapes" + with pytest.raises(tvm.TVMError, match=err_rgx): + _check_used_memory_annotations(mod, [], []) + + +def test_dynamic_callsite_tensor_not_supported(): + """ + Test to check dynamic callsite tensor error. + """ + + def get_inner_func(): + x = relay.var("x", shape=(relay.Any(), 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + call = relay.Call(get_inner_func(), [ifm]) + mod = tvm.IRModule.from_expr(call) + + err_rgx = r"AnnotateUsedMemory does not support dynamic shapes" + with pytest.raises(tvm.TVMError, match=err_rgx): + _check_used_memory_annotations(mod, [], [])