From 79d54548dc7dfb26e64fe9b4ad3b5168f2ac8a2f Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 2 Mar 2021 20:02:57 -0800 Subject: [PATCH] [RELAY] Modify some passes to not stack overflow on many lets. (#7558) * [RELAY] Modify some passes to not stack overflow on many lets. Passes modified: - inline primitives - dead code - lambda lift * one fix * small fix * .at -> [] * fix --- include/tvm/relay/expr_functor.h | 3 +- src/relay/backend/vm/compiler.cc | 15 ++++--- src/relay/backend/vm/inline_primitives.cc | 15 ++++++- src/relay/backend/vm/lambda_lift.cc | 35 +++++++++++------ src/relay/transforms/dead_code.cc | 48 +++++++++++++++++------ 5 files changed, 85 insertions(+), 31 deletions(-) diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index d53658f87f40..e6eec61a7e9d 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -88,7 +88,8 @@ class ExprFunctor { * \return The result of the call */ virtual R VisitExpr(const Expr& n, Args... args) { - ICHECK(n.defined()); + ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may " + "have generated invalid data."; static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 0718191a2ff6..251a55f10b72 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -376,11 +376,16 @@ class VMFunctionCompiler : ExprFunctor { CompileMatch(match); } - void VisitExpr_(const LetNode* let_node) { - DLOG(INFO) << PrettyPrint(let_node->value); - this->VisitExpr(let_node->value); - var_register_map_.insert({let_node->var, this->last_register_}); - this->VisitExpr(let_node->body); + void VisitExpr_(const LetNode* l) final { + Expr let_binding = GetRef(l); + const LetNode* let; + while ((let = let_binding.as())) { + VisitExpr(let->value); + var_register_map_.insert({let->var, this->last_register_}); + let_binding = let->body; + } + + VisitExpr(let_binding); } void VisitExpr_(const TupleGetItemNode* get_node) { diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 650df99645e7..eb848eb7a828 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -58,8 +58,19 @@ struct PrimitiveInliner : ExprMutator { explicit PrimitiveInliner(const IRModule& module) : module_(module) {} Expr VisitExpr_(const LetNode* let_node) { - var_map.insert({let_node->var, VisitExpr(let_node->value)}); - return ExprMutator::VisitExpr_(let_node); + auto pre_visit = [this](const LetNode* op) { + var_map.insert({op->var, this->VisitExpr(op->value)}); + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + // Visit body and cache the op + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + this->memo_[expr] = Let(op->var, value, body); + }; + ExpandANormalForm(let_node, pre_visit, post_visit); + return memo_[GetRef(let_node)]; } Expr VisitExpr_(const CallNode* call) { diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index fe9a544a719e..cc530a10188e 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -61,19 +61,30 @@ class LambdaLifter : public ExprMutator { explicit LambdaLifter(const IRModule& module) : module_(module) {} Expr VisitExpr_(const LetNode* let_node) final { - bool is_lambda = false; - if (auto func = let_node->value.as()) { - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - is_lambda = true; - letrec_.push_back(let_node->var); + auto pre_visit = [this](const LetNode* op) { + bool is_lambda = false; + if (auto func = op->value.as()) { + if (!func->HasNonzeroAttr(attr::kPrimitive)) { + is_lambda = true; + this->letrec_.push_back(op->var); + } } - } - auto value = VisitExpr(let_node->value); - if (is_lambda) { - letrec_.pop_back(); - } - auto body = VisitExpr(let_node->body); - return Let(let_node->var, value, body); + Expr value = this->VisitExpr(op->value); + + if (is_lambda) { + this->letrec_.pop_back(); + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + // Visit body and cache the op + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + this->memo_[expr] = Let(op->var, value, body); + }; + ExpandANormalForm(let_node, pre_visit, post_visit); + return memo_[GetRef(let_node)]; } Expr VisitExpr_(const CallNode* call_node) final { diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index 2e7c08a684dc..26624e438b8a 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -46,10 +46,16 @@ class FindDef : private ExprVisitor { VarMap expr_map_; void VisitExpr_(const LetNode* l) final { - ICHECK_EQ(expr_map_.count(l->var), 0); - expr_map_[l->var] = l->value; - VisitExpr(l->value); - VisitExpr(l->body); + auto pre_visit = [this](const LetNode* op) { + ICHECK_EQ(expr_map_.count(op->var), 0); + expr_map_[op->var] = op->value; + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(l, pre_visit, post_visit); } friend CalcDep; @@ -81,12 +87,24 @@ class Eliminator : private ExprMutator { } Expr VisitExpr_(const LetNode* op) final { - Var v = op->var; - if (HasLet(v)) { - return Let(v, VisitExpr(op->value), VisitExpr(op->body)); - } else { - return VisitExpr(op->body); - } + auto pre_visit = [this](const LetNode* op) { + if (HasLet(op->var)) { + Expr value = this->VisitExpr(op->value); + } + }; + auto post_visit = [this](const LetNode* op) { + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + Var v = op->var; + if (HasLet(v)) { + Expr value = this->VisitExpr(op->value); + this->memo_[expr] = Let(v, value, body); + } else { + this->memo_[expr] = body; + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; } }; @@ -121,7 +139,15 @@ class CalcDep : protected MixedModeVisitor { } } - void VisitExpr_(const LetNode* l) final { VisitExpr(l->body); } + void VisitExpr_(const LetNode* l) final { + Expr let_binding = GetRef(l); + const LetNode* let; + while ((let = let_binding.as())) { + let_binding = let->body; + visit_counter_[l] += 1; + } + VisitExpr(let_binding); + } void VisitExpr_(const VarNode* v) final { Var var = GetRef(v);