Skip to content

Commit

Permalink
[RELAY] Modify some passes to not stack overflow on many lets. (apach…
Browse files Browse the repository at this point in the history
…e#7558)

* [RELAY] Modify some passes to not stack overflow on many lets.

Passes modified:
- inline primitives
- dead code
- lambda lift

* one fix

* small fix

* .at -> []

* fix
  • Loading branch information
tkonolige authored and trevor-m committed May 11, 2021
1 parent f689575 commit 79d5454
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 31 deletions.
3 changes: 2 additions & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
ICHECK(n.defined());
ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
"have generated invalid data.";
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
Expand Down
15 changes: 10 additions & 5 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CompileMatch(match);
}

void VisitExpr_(const LetNode* let_node) {
DLOG(INFO) << PrettyPrint(let_node->value);
this->VisitExpr(let_node->value);
var_register_map_.insert({let_node->var, this->last_register_});
this->VisitExpr(let_node->body);
void VisitExpr_(const LetNode* l) final {
Expr let_binding = GetRef<Expr>(l);
const LetNode* let;
while ((let = let_binding.as<LetNode>())) {
VisitExpr(let->value);
var_register_map_.insert({let->var, this->last_register_});
let_binding = let->body;
}

VisitExpr(let_binding);
}

void VisitExpr_(const TupleGetItemNode* get_node) {
Expand Down
15 changes: 13 additions & 2 deletions src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,19 @@ struct PrimitiveInliner : ExprMutator {
explicit PrimitiveInliner(const IRModule& module) : module_(module) {}

Expr VisitExpr_(const LetNode* let_node) {
var_map.insert({let_node->var, VisitExpr(let_node->value)});
return ExprMutator::VisitExpr_(let_node);
auto pre_visit = [this](const LetNode* op) {
var_map.insert({op->var, this->VisitExpr(op->value)});
};
auto post_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
Expr value = this->VisitExpr(op->value);
// Visit body and cache the op
Expr body = this->VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
this->memo_[expr] = Let(op->var, value, body);
};
ExpandANormalForm(let_node, pre_visit, post_visit);
return memo_[GetRef<Expr>(let_node)];
}

Expr VisitExpr_(const CallNode* call) {
Expand Down
35 changes: 23 additions & 12 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,30 @@ class LambdaLifter : public ExprMutator {
explicit LambdaLifter(const IRModule& module) : module_(module) {}

Expr VisitExpr_(const LetNode* let_node) final {
bool is_lambda = false;
if (auto func = let_node->value.as<FunctionNode>()) {
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
is_lambda = true;
letrec_.push_back(let_node->var);
auto pre_visit = [this](const LetNode* op) {
bool is_lambda = false;
if (auto func = op->value.as<FunctionNode>()) {
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
is_lambda = true;
this->letrec_.push_back(op->var);
}
}
}
auto value = VisitExpr(let_node->value);
if (is_lambda) {
letrec_.pop_back();
}
auto body = VisitExpr(let_node->body);
return Let(let_node->var, value, body);
Expr value = this->VisitExpr(op->value);

if (is_lambda) {
this->letrec_.pop_back();
}
};
auto post_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
Expr value = this->VisitExpr(op->value);
// Visit body and cache the op
Expr body = this->VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
this->memo_[expr] = Let(op->var, value, body);
};
ExpandANormalForm(let_node, pre_visit, post_visit);
return memo_[GetRef<Expr>(let_node)];
}

Expr VisitExpr_(const CallNode* call_node) final {
Expand Down
48 changes: 37 additions & 11 deletions src/relay/transforms/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,16 @@ class FindDef : private ExprVisitor {
VarMap<Expr> expr_map_;

void VisitExpr_(const LetNode* l) final {
ICHECK_EQ(expr_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
VisitExpr(l->value);
VisitExpr(l->body);
auto pre_visit = [this](const LetNode* op) {
ICHECK_EQ(expr_map_.count(op->var), 0);
expr_map_[op->var] = op->value;
this->VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
this->VisitExpr(op->body);
this->visit_counter_[op] += 1;
};
ExpandANormalForm(l, pre_visit, post_visit);
}

friend CalcDep;
Expand Down Expand Up @@ -81,12 +87,24 @@ class Eliminator : private ExprMutator {
}

Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return Let(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
auto pre_visit = [this](const LetNode* op) {
if (HasLet(op->var)) {
Expr value = this->VisitExpr(op->value);
}
};
auto post_visit = [this](const LetNode* op) {
Expr body = this->VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
Var v = op->var;
if (HasLet(v)) {
Expr value = this->VisitExpr(op->value);
this->memo_[expr] = Let(v, value, body);
} else {
this->memo_[expr] = body;
}
};
ExpandANormalForm(op, pre_visit, post_visit);
return memo_[GetRef<Expr>(op)];
}
};

Expand Down Expand Up @@ -121,7 +139,15 @@ class CalcDep : protected MixedModeVisitor {
}
}

void VisitExpr_(const LetNode* l) final { VisitExpr(l->body); }
void VisitExpr_(const LetNode* l) final {
Expr let_binding = GetRef<Expr>(l);
const LetNode* let;
while ((let = let_binding.as<LetNode>())) {
let_binding = let->body;
visit_counter_[l] += 1;
}
VisitExpr(let_binding);
}

void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
Expand Down

0 comments on commit 79d5454

Please sign in to comment.