From eb403cc8a05dae0a954849ec2919ea1c684ac9b4 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Thu, 18 Nov 2021 17:16:08 -0800 Subject: [PATCH] [checkpoint] BYOC + dyn shapes unit test passes --- NEWS.md | 2 +- src/relay/backend/te_compiler.cc | 5 +- src/relay/backend/vm/compiler.cc | 20 ++-- src/relay/transforms/dead_code.cc | 170 ++++++++++++++---------------- src/runtime/vm/executable.cc | 8 +- 5 files changed, 100 insertions(+), 105 deletions(-) diff --git a/NEWS.md b/NEWS.md index d9397d83b772f..cfd51173aff47 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1734,7 +1734,7 @@ Rust language support in TVM includes two parts. 1. The frontend wraps the curre * [Relay] Fix memory leak in the interpreter (#4155) * [rpc] use callback func to do send & recv (#4147) * Add `lift_if_then_else` pass to improve loop partitioning (#3865) -* Decrease the complexity of CalcUse from exponential to linear (#4053) +* Decrease the complexity of UsageVisitor from exponential to linear (#4053) * [IR] Make iterators compatible with constructors of STL containers (#3624) * [Relay][Pass] Avoid FoldConstant folding some ops (#4245) * [Relay][Prelude] More dtypes support in `tensor_t` (#4233) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 436847b7621eb..9de4f223210e3 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -297,11 +297,14 @@ class TECompilerImpl : public TECompilerNode { return GetUniqueName(mangled, &name_map_); }); - // Skip lowering for device copy node. + // Skip lowering for device copy node, but make the special __copy hacked function call + // appear to resolve to an external. 
+ // TODO(mbs): device_copy cleanup const Expr body = (key->source_func)->body; if (auto* call_node = body.as()) { if (call_node->attrs.as()) { VLOG(1) << "is a device_copy"; + value->cached_func->funcs->Add(value->cached_func->prim_fn_var, key->source_func); return value; } } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index f66505a4322b9..ca09d80d46596 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -871,10 +871,8 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) // Run the optimizations necessary to target the VM. context_.module = OptimizeModuleImpl(std::move(mod)); - // Populate the global map. - // - // This maps global variables bound to Functions to a global index in the VMFunction table. - // (Note that primitive functions have their own map). + // Build the map from global variables bound to Functions to a global index in the + // VMFunction table. size_t num_functions = PopulateGlobalMap(); // Next we get ready by allocating space for @@ -1100,14 +1098,16 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { } size_t VMCompiler::PopulateGlobalMap() { - size_t global_index = 0; - // Allocate a VMFunction index for every function - for (const auto& pair : context_.module->functions) { - if (pair.second.as()) { - context_.global_map.emplace(pair.first, global_index++); + // Allocate a VMFunction index for every Relay Function we could call. + // Excludes PrimFuncs and externs, which are managed by the primitive_map_. 
+ for (const auto& kv : context_.module->functions) { + if (const auto* function_node = kv.second.as()) { + if (!function_node->GetAttr(attr::kExternalSymbol)) { + context_.global_map.emplace(kv.first, context_.global_map.size()); + } } } - return global_index; + return context_.global_map.size(); } void VMCompiler::Codegen() { diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index dcd35811c7d6d..357e03c47f78d 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -326,127 +326,125 @@ class PurityVisitor : ExprFunctor { std::unordered_map global_var_to_purity_; }; -/*! \brief Accumulate all the let-bound values. */ -class FindDef : public ExprVisitor { +/*! \brief Accumulate the bound values and usage count for each let-bound variable. */ +class UsageVisitor : public ExprVisitor { public: std::unordered_map let_bound_values_; - - private: - void VisitExpr_(const LetNode* l) final { - auto pre_visit = [this](const LetNode* op) { - ICHECK_EQ(let_bound_values_.count(op->var.get()), 0); - let_bound_values_[op->var.get()] = op->value; - VisitExpr(op->value); - }; - auto post_visit = [this](const LetNode* op) { - VisitExpr(op->body); - visit_counter_[op] += 1; - }; - ExpandANormalForm(l, pre_visit, post_visit); - } -}; - -/*! \brief Accumulate the usage count for each let-bound variable. */ -class CalcUse : public MixedModeVisitor { - public: std::unordered_map use_map_; - explicit CalcUse(const std::unordered_map* let_bound_values) - : MixedModeVisitor(2), let_bound_values_(let_bound_values) {} + explicit UsageVisitor(const std::unordered_map* var_to_purity) + : var_to_purity_(var_to_purity) {} - using MixedModeVisitor::VisitExpr_; - - void VisitLeaf(const Expr& e) final { - visit_counter_[e.get()]++; - // The dce code seprate variable into three parts: - // used 0 times (remove) - // used 1 times (inline) - // used 2 times (dont do anything). 
- if (visit_counter_[e.get()] <= 2) { - using TParent = ExprFunctor; - TParent::VisitExpr(e); + void VisitExpr(const Expr& expr) final { + // Once we've seen 2 usages of a variable we know it can be neither elided nor inlined. + if (++visit_counter_[expr.get()] <= 2) { + ExprFunctor::VisitExpr(expr); } } - void VisitExpr_(const LetNode* l) final { - Expr let_binding = GetRef(l); - const LetNode* let; - while ((let = let_binding.as())) { - let_binding = let->body; - visit_counter_[l] += 1; + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + ++visit_counter_[inner_let_node]; + let_bound_values_[inner_let_node->var.get()] = inner_let_node->value; + ICHECK(var_to_purity_->count(inner_let_node->var.get())); + if (var_to_purity_->at(inner_let_node->var.get())) { + // We'll defer visiting the let-bound value until we've seen a use of the variable. + // no-op. + } else { + // The let-bound value is impure so must visit now. + VisitExpr(inner_let_node->value); + } + expr = inner_let_node->body; } - VisitExpr(let_binding); + VisitExpr(expr); } - void VisitExpr_(const VarNode* v) final { - Var var = GetRef(v); - ++use_map_[var.get()]; - if (use_map_[var.get()] == 1 && let_bound_values_->count(var.get()) > 0) { - VisitExpr(let_bound_values_->at(var.get())); + void VisitExpr_(const VarNode* var_node) final { + if (let_bound_values_.count(var_node)) { + size_t& n = use_map_[var_node]; + ++n; + VLOG(1) << var_node->name_hint() << " = " << n; + ICHECK(var_to_purity_->count(var_node)); + if (n == 1 && var_to_purity_->at(var_node)) { + // Now that we have at least one use of the let-bound var, we know the let-bound + // value is necessary. + VisitExpr(let_bound_values_[var_node]); + } } } - const std::unordered_map* let_bound_values_; + const std::unordered_map* var_to_purity_; }; /*! \brief Eliminate/inline let-bound values where safe to do so. 
*/ -class Eliminator : public ExprMutator { +class EliminatorMutator : public ExprMutator { public: - Eliminator(bool inline_once, const std::unordered_map* let_bound_values, - const std::unordered_map* use_map, - const std::unordered_map* var_to_purity) + EliminatorMutator(bool inline_once, + const std::unordered_map* let_bound_values, + const std::unordered_map* use_map, + const std::unordered_map* var_to_purity) : inline_once_(inline_once), let_bound_values_(let_bound_values), use_map_(use_map), var_to_purity_(var_to_purity) {} private: - bool RetainLetBinding(const VarNode* var_node) { + enum Action { kElide, kInline, kNoChange }; + + /*! \brief What should we do with let-binding for \p var_node? */ + Action ActionFor(const VarNode* var_node) { + if (let_bound_values_->count(var_node) == 0) { + // Not let-bound var. + return kNoChange; + } auto itr = var_to_purity_->find(var_node); ICHECK(itr != var_to_purity_->end()) << PrettyPrint(GetRef(var_node)); if (!itr->second) { - // The let-bound value is impure -- we must retain it. - return true; + // The let-bound value is impure -- we must leave it exactly where it is. + return kNoChange; } switch (use_map_->count(var_node) ? use_map_->at(var_node) : 0) { case 0: - // We can elide pure and unused variables. - return false; + return kElide; case 1: - // Can elide if we know the let-bound pure expression will be inlined to its unique use - // site. - return !inline_once_; + return inline_once_ ? kInline : kNoChange; default: - return true; + return kNoChange; } } Expr VisitExpr_(const VarNode* var_node) final { - if (let_bound_values_->count(var_node) && !RetainLetBinding(var_node)) { - // Safe to inline. 
+ if (ActionFor(var_node) == kInline) { VLOG(1) << "inlining let-bound variable:" << std::endl << PrettyPrint(GetRef(var_node)); return VisitExpr(let_bound_values_->at(var_node)); + } else { + return GetRef(var_node); } - return GetRef(var_node); } Expr VisitExpr_(const LetNode* op) final { auto pre_visit = [this](const LetNode* op) { - if (RetainLetBinding(op->var.get())) { - Expr value = VisitExpr(op->value); + if (ActionFor(op->var.get()) != kElide) { + (void)VisitExpr(op->value); } }; auto post_visit = [this](const LetNode* op) { Expr body = VisitExpr(op->body); auto expr = GetRef(op); - if (RetainLetBinding(op->var.get())) { - Expr value = VisitExpr(op->value); - this->memo_[expr] = Let(op->var, value, body); - } else { - // Safe to elide. - VLOG(1) << "eliding let-binding for unused/inlined variable:" << std::endl - << PrettyPrint(op->var); - this->memo_[expr] = body; + switch (ActionFor(op->var.get())) { + case kElide: + VLOG(1) << "eliding let-bound variable:" << std::endl << PrettyPrint(op->var); + memo_[expr] = body; + break; + case kInline: + // Already inlined at use-side. + memo_[expr] = body; + break; + case kNoChange: + Expr value = VisitExpr(op->value); + memo_[expr] = Let(op->var, value, body); + break; } }; ExpandANormalForm(op, pre_visit, post_visit); @@ -468,9 +466,9 @@ Pass DeadCodeElimination(bool inline_once) { std::unordered_map var_to_purity; { VLOG(1) << "determine purity"; - PurityVisitor is_pure_visitor(mod); - is_pure_visitor.VisitModule(); - var_to_purity = is_pure_visitor.GetPurityMap(); + PurityVisitor purity_visitor(mod); + purity_visitor.VisitModule(); + var_to_purity = purity_visitor.GetPurityMap(); } IRModule result(/*functions=*/{}, mod->type_definitions, mod->Imports(), mod->source_map); @@ -480,23 +478,15 @@ Pass DeadCodeElimination(bool inline_once) { VLOG(1) << "processing " << PrettyPrint(kv.first); - // Collect the let-bound values in case we need to inline them. 
- VLOG(1) << "find let-bound values"; - FindDef fd; - fd.VisitExpr(function_node->body); - - VLOG(1) << "count let-bound variables"; - CalcUse cd(&fd.let_bound_values_); - cd.VisitExpr(function_node->body); + VLOG(1) << "count usage"; + UsageVisitor usage_visitor(&var_to_purity); + usage_visitor.VisitExpr(function); // Actually eliminate/inline the let-bindings. VLOG(1) << "eliminate"; - Eliminator el(inline_once, &fd.let_bound_values_, &cd.use_map_, &var_to_purity); - auto new_body = el.VisitExpr(function_node->body); - - auto new_function = Function(function->params, new_body, function->ret_type, - function->type_params, function->attrs, function->span); - result->Add(kv.first, new_function); + EliminatorMutator eliminator_mutator(inline_once, &usage_visitor.let_bound_values_, + &usage_visitor.use_map_, &var_to_purity); + result->Add(kv.first, Downcast(eliminator_mutator.VisitExpr(function))); } else { // PrimFuncs come across unchanged. result->Add(kv.first, kv.second); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 0d8ca3ed39275..b613a03bfc5cd 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -143,10 +143,12 @@ std::string Executable::GetBytecode() const { for (size_t idx = 0; idx < func.instructions.size(); ++idx) { const auto& instr = func.instructions[idx]; const auto& serialized_instr = SerializeInstruction(instr); - oss << std::setw(2) << idx << ": " << serialized_instr.opcode << " "; + std::ostringstream line; + line << std::setw(2) << idx << ": " << serialized_instr.opcode << " "; for (auto it : serialized_instr.fields) { - oss << it << " "; + line << it << " "; } + oss << std::setw(40) << std::setfill(' ') << std::left << line.str(); oss << " # " << instr; if (oss.str().back() != '\n') oss << std::endl; } @@ -204,7 +206,7 @@ std::string Executable::GetPrimitives() const { return left.first < right.first; }); for (const auto& entry : entries) { - os << "VM Primitive[" << entry.first << "]: " 
<< entry.second << std::endl; + os << "VM PackedFunc[" << entry.first << "]: " << entry.second << std::endl; } return os.str(); }