Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] Modify some passes to not stack overflow on many lets. #7558

Merged
merged 5 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
ICHECK(n.defined());
ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
"have generated invalid data.";
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
Expand Down
15 changes: 10 additions & 5 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CompileMatch(match);
}

void VisitExpr_(const LetNode* let_node) {
DLOG(INFO) << PrettyPrint(let_node->value);
this->VisitExpr(let_node->value);
var_register_map_.insert({let_node->var, this->last_register_});
this->VisitExpr(let_node->body);
void VisitExpr_(const LetNode* l) final {
Expr let_binding = GetRef<Expr>(l);
const LetNode* let;
while ((let = let_binding.as<LetNode>())) {
VisitExpr(let->value);
var_register_map_.insert({let->var, this->last_register_});
let_binding = let->body;
}

VisitExpr(let_binding);
}

void VisitExpr_(const TupleGetItemNode* get_node) {
Expand Down
15 changes: 13 additions & 2 deletions src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,19 @@ struct PrimitiveInliner : ExprMutator {
explicit PrimitiveInliner(const IRModule& module) : module_(module) {}

Expr VisitExpr_(const LetNode* let_node) {
var_map.insert({let_node->var, VisitExpr(let_node->value)});
return ExprMutator::VisitExpr_(let_node);
auto pre_visit = [this](const LetNode* op) {
var_map.insert({op->var, this->VisitExpr(op->value)});
};
auto post_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
Expr value = this->VisitExpr(op->value);
// Visit body and cache the op
Expr body = this->VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
this->memo_[expr] = Let(op->var, value, body);
};
ExpandANormalForm(let_node, pre_visit, post_visit);
return memo_[GetRef<Expr>(let_node)];
}

Expr VisitExpr_(const CallNode* call) {
Expand Down
35 changes: 23 additions & 12 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,30 @@ class LambdaLifter : public ExprMutator {
explicit LambdaLifter(const IRModule& module) : module_(module) {}

Expr VisitExpr_(const LetNode* let_node) final {
bool is_lambda = false;
if (auto func = let_node->value.as<FunctionNode>()) {
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
is_lambda = true;
letrec_.push_back(let_node->var);
auto pre_visit = [this](const LetNode* op) {
bool is_lambda = false;
if (auto func = op->value.as<FunctionNode>()) {
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
is_lambda = true;
this->letrec_.push_back(op->var);
}
}
}
auto value = VisitExpr(let_node->value);
if (is_lambda) {
letrec_.pop_back();
}
auto body = VisitExpr(let_node->body);
return Let(let_node->var, value, body);
Expr value = this->VisitExpr(op->value);

if (is_lambda) {
this->letrec_.pop_back();
}
};
auto post_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
Expr value = this->VisitExpr(op->value);
// Visit body and cache the op
Expr body = this->VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
this->memo_[expr] = Let(op->var, value, body);
};
ExpandANormalForm(let_node, pre_visit, post_visit);
return memo_[GetRef<Expr>(let_node)];
}

Expr VisitExpr_(const CallNode* call_node) final {
Expand Down
44 changes: 32 additions & 12 deletions src/relay/transforms/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,16 @@ class FindDef : private ExprVisitor {
VarMap<Expr> expr_map_;

void VisitExpr_(const LetNode* l) final {
ICHECK_EQ(expr_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
VisitExpr(l->value);
VisitExpr(l->body);
auto pre_visit = [this](const LetNode* op) {
ICHECK_EQ(expr_map_.count(op->var), 0);
expr_map_[op->var] = op->value;
this->VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
this->VisitExpr(op->body);
this->visit_counter_[op] += 1;
};
ExpandANormalForm(l, pre_visit, post_visit);
}

friend CalcDep;
Expand Down Expand Up @@ -77,16 +83,26 @@ class Eliminator : private ExprMutator {

Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_.at(v));
tkonolige marked this conversation as resolved.
Show resolved Hide resolved
}

Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return Let(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
auto pre_visit = [this](const LetNode* op) { Expr value = this->VisitExpr(op->value); };
auto post_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
Expr value = this->VisitExpr(op->value);
// Visit body and cache the op
Expr body = this->VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
Var v = op->var;
if (HasLet(v)) {
this->memo_[expr] = Let(v, value, body);
} else {
this->memo_[expr] = body;
}
};
ExpandANormalForm(op, pre_visit, post_visit);
return memo_.at(GetRef<Expr>(op));
tkonolige marked this conversation as resolved.
Show resolved Hide resolved
}
};

Expand Down Expand Up @@ -121,7 +137,11 @@ class CalcDep : protected MixedModeVisitor {
}
}

void VisitExpr_(const LetNode* l) final { VisitExpr(l->body); }
void VisitExpr_(const LetNode* l) final {
auto pre_visit = [](const LetNode* op) {};
auto post_visit = [this](const LetNode* op) { this->VisitExpr(op->body); };
ExpandANormalForm(l, pre_visit, post_visit);
}

void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
Expand Down