From 1495b342056859c000f053c5a464dd3816571d5b Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Fri, 27 Aug 2021 06:04:09 +0100 Subject: [PATCH] Change AOT from ExprVisitor to MixedModeVisitor (#8856) This should allow better scale-ability for AOT when targeting larger networks. --- src/relay/backend/aot_executor_codegen.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 942bc0d1d44a..2fb35f3a2e27 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -53,7 +53,7 @@ using StorageMap = * This is an on demand allocator for AOT. A new temporary * (storage allocator identifier) is allocated for each operation. */ -class AOTOnDemandAllocator : public ExprVisitor { +class AOTOnDemandAllocator : public MixedModeVisitor { public: // run the visitor on a function. void Run(const Function& func) { @@ -84,10 +84,7 @@ class AOTOnDemandAllocator : public ExprVisitor { AssignReturnSid(GetRef(op)); } - void VisitExpr_(const VarNode* op) final { - ExprVisitor::VisitExpr_(op); - AssignReturnSid(GetRef(op)); - } + void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } void VisitExpr_(const FunctionNode* op) final { // do not recurse into sub function. @@ -218,7 +215,7 @@ class AOTOnDemandAllocator : public ExprVisitor { }; /*! \brief Code generator for AOT executor */ -class AOTExecutorCodegen : public ExprVisitor { +class AOTExecutorCodegen : public MixedModeVisitor { protected: /*! * \brief Utility function to allocate a DLTensor or TVMValue @@ -437,7 +434,6 @@ class AOTExecutorCodegen : public ExprVisitor { void VisitExpr_(const OpNode* op) override { throw std::runtime_error("can not compile op in non-eta expanded form"); } - void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); } void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } void VisitExpr_(const FunctionNode* op) override { ICHECK(op->GetAttr(attr::kCompiler).defined())