Skip to content

Commit

Permalink
Address Christopher's comments from #8788 (#9197)
Browse files Browse the repository at this point in the history
We don't need the Optional<IRModule> on ToANormalForm and friends.
  • Loading branch information
mbs-octoml authored Oct 6, 2021
1 parent c9c0688 commit 01771ab
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 21 deletions.
3 changes: 1 addition & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,11 @@ TVM_DLL Pass ToANormalForm();
/*!
* \brief ToANormalForm but on incomplete graph.
*
* \param maybe_mod optional module holding definitions for global vars in \p expr
* \param expr the graph.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& expr);
TVM_DLL Expr ToANormalForm(const Expr& expr);

/*!
* \brief Turn an expression into continuation passing style(CPS).
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/higher_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ struct ReverseAD : ExprMutator {
return Call(bpv, {});
});
Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
ll->Push(RefWrite(bp, transform::ToANormalForm(mod, nbp)));
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg);
Scope LCA(Scope lhs, Scope rhs);

// For basic block normal form.
Expr ToBasicBlockNormalFormAux(const Optional<IRModule>& maybe_mod, const Expr& e);
Expr ToBasicBlockNormalFormAux(const Expr& e);

// ToANormalForm for expressions and as a Pass are declared in transform.h

Expand Down
38 changes: 22 additions & 16 deletions src/relay/transforms/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,25 +149,31 @@ namespace {
*/
class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::LexicalOnDeviceMixin {
public:
static Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& e,
const DependencyGraph& dg, NodeScopeMap* node_scope) {
Fill fi(maybe_mod, dg, node_scope, nullptr);
static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) {
Fill fi(dg, node_scope, nullptr);
return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
}

// For basic block normal form, bind expressions only if the original expression's scope
// should be lifted
static Expr ToBasicBlockNormalForm(const Optional<IRModule>& maybe_mod, const Expr& e,
const DependencyGraph& dg, NodeScopeMap* node_scope,
ExprSet* lifted) {
Fill fi(maybe_mod, dg, node_scope, lifted);
static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
NodeScopeMap* node_scope, ExprSet* lifted) {
Fill fi(dg, node_scope, lifted);
return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
}

private:
Fill(const Optional<IRModule>& maybe_mod, const DependencyGraph& dg, NodeScopeMap* node_scope,
ExprSet* include_set)
: transform::LexicalOnDeviceMixin(maybe_mod),
// Note: Conversion to ANF needn't care about the devices for global vars since all that can
// happen with them is to go from:
// ...@g...
// to:
// let %x = @g;
// ...
// ...%x...
// In that case the code will ask for the device for @g, get kInvalidDeviceType, then
// MaybeOnDevice @g, which is always a no-op.
Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set)
: transform::LexicalOnDeviceMixin(Optional<IRModule>()),
dg_(dg),
node_scope_(node_scope),
include_set_(include_set) {}
Expand Down Expand Up @@ -373,7 +379,7 @@ IRModule ModuleToANormalForm(const IRModule& mod) {
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
Function func = GetRef<Function>(n);
Function ret = Downcast<Function>(transform::ToANormalForm(mod, func));
Function ret = Downcast<Function>(transform::ToANormalForm(func));
ICHECK_EQ(FreeVars(ret).size(), 0) << "rewritten:" << std::endl
<< PrettyPrint(ret) << std::endl
<< "should not have free vars: " << FreeVars(ret);
Expand All @@ -394,7 +400,7 @@ IRModule ModuleToANormalForm(const IRModule& mod) {

} // namespace

Expr ToBasicBlockNormalFormAux(const Optional<IRModule>& maybe_mod, const Expr& e) {
Expr ToBasicBlockNormalFormAux(const Expr& e) {
// calculate all the dependency between nodes.
support::Arena arena;
DependencyGraph dg = DependencyGraph::Create(&arena, e);
Expand All @@ -403,12 +409,12 @@ Expr ToBasicBlockNormalFormAux(const Optional<IRModule>& maybe_mod, const Expr&
* We also record the set of expressions whose scope is lifted.
*/
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToBasicBlockNormalForm(maybe_mod, e, dg, &scopes.first, &scopes.second);
return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
}

namespace transform {

Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& e) {
Expr ToANormalForm(const Expr& e) {
/* When you lift a lambda, what is inside is also being lift.
*
* So we must determine the scope of the lambda before determining the scope of it's body.
Expand All @@ -431,7 +437,7 @@ Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& e) {
* We do an additional pass to fill all the LetList and we are done.
*/
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToANormalForm(maybe_mod, e, dg, &scopes.first);
return Fill::ToANormalForm(e, dg, &scopes.first);
}

Pass ToANormalForm() {
Expand All @@ -445,7 +451,7 @@ TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() {
});

TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr").set_body_typed([](const Expr& e) {
return ToANormalForm(Optional<IRModule>(), e);
return ToANormalForm(e);
});

} // namespace transform
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/to_basic_block_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) {
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
Function func = GetRef<Function>(n);
Function ret = Downcast<Function>(ToBasicBlockNormalFormAux(mod, func));
Function ret = Downcast<Function>(ToBasicBlockNormalFormAux(func));
VLOG(1) << "rewritten:" << std::endl
<< PrettyPrint(func) << std::endl
<< "to BasicBlockANF:" << std::endl
Expand Down

0 comments on commit 01771ab

Please sign in to comment.