Skip to content

Commit

Permalink
[checkpoint] BOYC + dyn shapes unit test passes
Browse files Browse the repository at this point in the history
  • Loading branch information
mbs-octoml committed Nov 19, 2021
1 parent 845f516 commit eb403cc
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 105 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -1734,7 +1734,7 @@ Rust language support in TVM includes two parts. 1. The frontend wraps the curre
* [Relay] Fix memory leak in the interpreter (#4155)
* [rpc] use callback func to do send & recv (#4147)
* Add `lift_if_then_else` pass to improve loop partitioning (#3865)
* Decrease the complexity of CalcUse from exponential to linear (#4053)
* Decrease the complexity of UsageVisitor from exponential to linear (#4053)
* [IR] Make iterators compatible with constructors of STL containers (#3624)
* [Relay][Pass] Avoid FoldConstant folding some ops (#4245)
* [Relay][Prelude] More dtypes support in `tensor_t` (#4233)
Expand Down
5 changes: 4 additions & 1 deletion src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,14 @@ class TECompilerImpl : public TECompilerNode {
return GetUniqueName(mangled, &name_map_);
});

// Skip lowering for device copy node.
// Skip lowering for device copy node, but make the special __copy hacked function call
// appear to resolve to an external.
// TODO(mbs): device_copy cleanup
const Expr body = (key->source_func)->body;
if (auto* call_node = body.as<CallNode>()) {
if (call_node->attrs.as<DeviceCopyAttrs>()) {
VLOG(1) << "is a device_copy";
value->cached_func->funcs->Add(value->cached_func->prim_fn_var, key->source_func);
return value;
}
}
Expand Down
20 changes: 10 additions & 10 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -871,10 +871,8 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host)
// Run the optimizations necessary to target the VM.
context_.module = OptimizeModuleImpl(std::move(mod));

// Populate the global map.
//
// This maps global variables bound to Functions to a global index in the VMFunction table.
// (Note that primitive functions have their own map).
// Build the map from global variables bound to Functions to a global index in the
// VMFunction table.
size_t num_functions = PopulateGlobalMap();

// Next we get ready by allocating space for
Expand Down Expand Up @@ -1100,14 +1098,16 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
}

size_t VMCompiler::PopulateGlobalMap() {
size_t global_index = 0;
// Allocate a VMFunction index for every function
for (const auto& pair : context_.module->functions) {
if (pair.second.as<FunctionNode>()) {
context_.global_map.emplace(pair.first, global_index++);
// Allocate a VMFunction index for every Relay Function we could call.
// Excludes PrimFuncs and externs, which are managed by the primitive_map_.
for (const auto& kv : context_.module->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
if (!function_node->GetAttr<String>(attr::kExternalSymbol)) {
context_.global_map.emplace(kv.first, context_.global_map.size());
}
}
}
return global_index;
return context_.global_map.size();
}

void VMCompiler::Codegen() {
Expand Down
170 changes: 80 additions & 90 deletions src/relay/transforms/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,127 +326,125 @@ class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
std::unordered_map<const GlobalVarNode*, Purity> global_var_to_purity_;
};

/*! \brief Accumulate all the let-bound values. */
class FindDef : public ExprVisitor {
/*! \brief Accumulate the bound values and usage count for each let-bound variable. */
class UsageVisitor : public ExprVisitor {
public:
std::unordered_map<const VarNode*, Expr> let_bound_values_;

private:
void VisitExpr_(const LetNode* l) final {
auto pre_visit = [this](const LetNode* op) {
ICHECK_EQ(let_bound_values_.count(op->var.get()), 0);
let_bound_values_[op->var.get()] = op->value;
VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
VisitExpr(op->body);
visit_counter_[op] += 1;
};
ExpandANormalForm(l, pre_visit, post_visit);
}
};

/*! \brief Accumulate the usage count for each let-bound variable. */
class CalcUse : public MixedModeVisitor {
public:
std::unordered_map<const VarNode*, size_t> use_map_;

explicit CalcUse(const std::unordered_map<const VarNode*, Expr>* let_bound_values)
: MixedModeVisitor(2), let_bound_values_(let_bound_values) {}
explicit UsageVisitor(const std::unordered_map<const VarNode*, bool>* var_to_purity)
: var_to_purity_(var_to_purity) {}

using MixedModeVisitor::VisitExpr_;

void VisitLeaf(const Expr& e) final {
visit_counter_[e.get()]++;
// The dce code seprate variable into three parts:
// used 0 times (remove)
// used 1 times (inline)
// used 2 times (dont do anything).
if (visit_counter_[e.get()] <= 2) {
using TParent = ExprFunctor<void(const Expr&)>;
TParent::VisitExpr(e);
void VisitExpr(const Expr& expr) final {
// Once we've seen 2 usages of a variable we know it can be neither elided nor inlined.
if (++visit_counter_[expr.get()] <= 2) {
ExprFunctor<void(const Expr&)>::VisitExpr(expr);
}
}

void VisitExpr_(const LetNode* l) final {
Expr let_binding = GetRef<Expr>(l);
const LetNode* let;
while ((let = let_binding.as<LetNode>())) {
let_binding = let->body;
visit_counter_[l] += 1;
void VisitExpr_(const LetNode* let_node) final {
Expr expr = GetRef<Expr>(let_node);
while (const auto* inner_let_node = expr.as<LetNode>()) {
++visit_counter_[inner_let_node];
let_bound_values_[inner_let_node->var.get()] = inner_let_node->value;
ICHECK(var_to_purity_->count(inner_let_node->var.get()));
if (var_to_purity_->at(inner_let_node->var.get())) {
// We'll defer visiting the let-bound value until we've seen a use of the variable.
// no-op.
} else {
// The let-bound value is impure so must visit now.
VisitExpr(inner_let_node->value);
}
expr = inner_let_node->body;
}
VisitExpr(let_binding);
VisitExpr(expr);
}

void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
++use_map_[var.get()];
if (use_map_[var.get()] == 1 && let_bound_values_->count(var.get()) > 0) {
VisitExpr(let_bound_values_->at(var.get()));
void VisitExpr_(const VarNode* var_node) final {
if (let_bound_values_.count(var_node)) {
size_t& n = use_map_[var_node];
++n;
VLOG(1) << var_node->name_hint() << " = " << n;
ICHECK(var_to_purity_->count(var_node));
if (n == 1 && var_to_purity_->at(var_node)) {
// Now that we have at least one use of the let-bound var, we know the let-bound
// value is necessary.
VisitExpr(let_bound_values_[var_node]);
}
}
}

const std::unordered_map<const VarNode*, Expr>* let_bound_values_;
const std::unordered_map<const VarNode*, bool>* var_to_purity_;
};

/*! \brief Eliminate/inline let-bound values where safe to do so. */
class Eliminator : public ExprMutator {
class EliminatorMutator : public ExprMutator {
public:
Eliminator(bool inline_once, const std::unordered_map<const VarNode*, Expr>* let_bound_values,
const std::unordered_map<const VarNode*, size_t>* use_map,
const std::unordered_map<const VarNode*, bool>* var_to_purity)
EliminatorMutator(bool inline_once,
const std::unordered_map<const VarNode*, Expr>* let_bound_values,
const std::unordered_map<const VarNode*, size_t>* use_map,
const std::unordered_map<const VarNode*, bool>* var_to_purity)
: inline_once_(inline_once),
let_bound_values_(let_bound_values),
use_map_(use_map),
var_to_purity_(var_to_purity) {}

private:
bool RetainLetBinding(const VarNode* var_node) {
enum Action { kElide, kInline, kNoChange };

/*! \brief What should we do with let-binding for \p var_node? */
Action ActionFor(const VarNode* var_node) {
if (let_bound_values_->count(var_node) == 0) {
// Not let-bound var.
return kNoChange;
}
auto itr = var_to_purity_->find(var_node);
ICHECK(itr != var_to_purity_->end()) << PrettyPrint(GetRef<Var>(var_node));
if (!itr->second) {
// The let-bound value is impure -- we must retain it.
return true;
// The let-bound value is impure -- we must leave it exactly where it is.
return kNoChange;
}
switch (use_map_->count(var_node) ? use_map_->at(var_node) : 0) {
case 0:
// We can elide pure and unused variables.
return false;
return kElide;
case 1:
// Can elide if we know the let-bound pure expression will be inlined to its unique use
// site.
return !inline_once_;
return inline_once_ ? kInline : kNoChange;
default:
return true;
return kNoChange;
}
}

Expr VisitExpr_(const VarNode* var_node) final {
if (let_bound_values_->count(var_node) && !RetainLetBinding(var_node)) {
// Safe to inline.
if (ActionFor(var_node) == kInline) {
VLOG(1) << "inlining let-bound variable:" << std::endl << PrettyPrint(GetRef<Var>(var_node));
return VisitExpr(let_bound_values_->at(var_node));
} else {
return GetRef<Var>(var_node);
}
return GetRef<Var>(var_node);
}

Expr VisitExpr_(const LetNode* op) final {
auto pre_visit = [this](const LetNode* op) {
if (RetainLetBinding(op->var.get())) {
Expr value = VisitExpr(op->value);
if (ActionFor(op->var.get()) != kElide) {
(void)VisitExpr(op->value);
}
};
auto post_visit = [this](const LetNode* op) {
Expr body = VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
if (RetainLetBinding(op->var.get())) {
Expr value = VisitExpr(op->value);
this->memo_[expr] = Let(op->var, value, body);
} else {
// Safe to elide.
VLOG(1) << "eliding let-binding for unused/inlined variable:" << std::endl
<< PrettyPrint(op->var);
this->memo_[expr] = body;
switch (ActionFor(op->var.get())) {
case kElide:
VLOG(1) << "eliding let-bound variable:" << std::endl << PrettyPrint(op->var);
memo_[expr] = body;
break;
case kInline:
// Already inlined at use-side.
memo_[expr] = body;
break;
case kNoChange:
Expr value = VisitExpr(op->value);
memo_[expr] = Let(op->var, value, body);
break;
}
};
ExpandANormalForm(op, pre_visit, post_visit);
Expand All @@ -468,9 +466,9 @@ Pass DeadCodeElimination(bool inline_once) {
std::unordered_map<const VarNode*, bool> var_to_purity;
{
VLOG(1) << "determine purity";
PurityVisitor is_pure_visitor(mod);
is_pure_visitor.VisitModule();
var_to_purity = is_pure_visitor.GetPurityMap();
PurityVisitor purity_visitor(mod);
purity_visitor.VisitModule();
var_to_purity = purity_visitor.GetPurityMap();
}

IRModule result(/*functions=*/{}, mod->type_definitions, mod->Imports(), mod->source_map);
Expand All @@ -480,23 +478,15 @@ Pass DeadCodeElimination(bool inline_once) {

VLOG(1) << "processing " << PrettyPrint(kv.first);

// Collect the let-bound values in case we need to inline them.
VLOG(1) << "find let-bound values";
FindDef fd;
fd.VisitExpr(function_node->body);

VLOG(1) << "count let-bound variables";
CalcUse cd(&fd.let_bound_values_);
cd.VisitExpr(function_node->body);
VLOG(1) << "count usage";
UsageVisitor usage_visitor(&var_to_purity);
usage_visitor.VisitExpr(function);

// Actually eliminate/inline the let-bindings.
VLOG(1) << "eliminate";
Eliminator el(inline_once, &fd.let_bound_values_, &cd.use_map_, &var_to_purity);
auto new_body = el.VisitExpr(function_node->body);

auto new_function = Function(function->params, new_body, function->ret_type,
function->type_params, function->attrs, function->span);
result->Add(kv.first, new_function);
EliminatorMutator eliminator_mutator(inline_once, &usage_visitor.let_bound_values_,
&usage_visitor.use_map_, &var_to_purity);
result->Add(kv.first, Downcast<Function>(eliminator_mutator.VisitExpr(function)));
} else {
// PrimFuncs come across unchanged.
result->Add(kv.first, kv.second);
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,12 @@ std::string Executable::GetBytecode() const {
for (size_t idx = 0; idx < func.instructions.size(); ++idx) {
const auto& instr = func.instructions[idx];
const auto& serialized_instr = SerializeInstruction(instr);
oss << std::setw(2) << idx << ": " << serialized_instr.opcode << " ";
std::ostringstream line;
line << std::setw(2) << idx << ": " << serialized_instr.opcode << " ";
for (auto it : serialized_instr.fields) {
oss << it << " ";
line << it << " ";
}
oss << std::setw(40) << std::setfill(' ') << std::left << line.str();
oss << " # " << instr;
if (oss.str().back() != '\n') oss << std::endl;
}
Expand Down Expand Up @@ -204,7 +206,7 @@ std::string Executable::GetPrimitives() const {
return left.first < right.first;
});
for (const auto& entry : entries) {
os << "VM Primitive[" << entry.first << "]: " << entry.second << std::endl;
os << "VM PackedFunc[" << entry.first << "]: " << entry.second << std::endl;
}
return os.str();
}
Expand Down

0 comments on commit eb403cc

Please sign in to comment.