Skip to content

Commit

Permalink
[PASS] UnrollLoop
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 5, 2017
1 parent d89917b commit 9bcaeb0
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 9 deletions.
10 changes: 9 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);

/*!
* \brief inline all calls of f in stmt.
Expand Down Expand Up @@ -97,6 +97,13 @@ Stmt Inline(Stmt stmt,
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Unroll loops whose extent is a constant.
* \param stmt The statement to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
* \return The transformed statement.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down Expand Up @@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);


} // namespace ir
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
<< num_args << " passed"
<< "but request arg" << i;
<< " but request arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]);
}

Expand Down
10 changes: 10 additions & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>

namespace tvm {
Expand All @@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
}
});

// Frontend entry point for ir::PostOrderVisit.
// args[0] is handed to ir::PostOrderVisit as the node to walk; args[1] is a
// PackedFunc that receives each NodeRef the visitor reports.
TVM_REGISTER_API(_pass_PostOrderVisit)
.set_body([](TVMArgs args, TVMRetValue *ret) {
    PackedFunc callback = args[1];
    ir::PostOrderVisit(args[0], [callback](const NodeRef& node) {
        callback(node);
      });
  });

// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
Expand All @@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
Expand Down
22 changes: 18 additions & 4 deletions src/pass/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,24 @@ class IRInline : public IRMutator {
if (op->func == f_) {
CHECK_EQ(op->value_index, 0);
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
CHECK_EQ(args_.size(), op->args.size());

bool has_side_effect = false;
for (size_t i = 0; i < op->args.size(); ++i) {
if (HasSideEffect(op->args[i])) has_side_effect = true;
}

if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
} else {
Map<Var, Expr> vmap;
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->args[i]);
}
expr = Substitute(
Evaluate::make(expr), vmap).as<Evaluate>()->value;
}
return expr;
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/pass/simple_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
std::unordered_map<const Variable*, Expr> smap;
};

// Substitute variables inside stmt.
// Every entry of value_map is copied into the mutator's smap, keyed by the
// raw Variable pointer of the entry's key, and stmt is then run through the
// mutator.
// NOTE(review): the paired signature/assignment lines below are the old and
// new sides of the same change (Map<IterVar, Expr> -> Map<Var, Expr>).
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first->var.get()] = kv.second;
m.smap[kv.first.get()] = kv.second;
}
return m.Mutate(stmt);
}
Expand Down
77 changes: 77 additions & 0 deletions src/pass/unroll_loop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*!
* Copyright (c) 2016 by Contributors
* Loop unrolling pass.
* \file unroll_loop.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../schedule/compute_expr.h"

namespace tvm {
namespace ir {

/*!
 * \brief Mutator that replaces constant-extent For loops by their unrolled body.
 *
 *  A loop is unrolled when its simplified extent is an integer constant that
 *  is non-negative and not larger than max_auto_step, or when the loop is
 *  marked ForType::Unrolled (in which case a constant extent is required).
 */
class LoopUnroller : public IRMutator {
 public:
  /*!
   * \brief Constructor.
   * \param max_auto_step Largest constant extent that is unrolled automatically.
   */
  explicit LoopUnroller(int max_auto_step)
      : max_auto_step_(max_auto_step) {
  }

  /*!
   * \brief Unroll the loop if it qualifies, otherwise defer to IRMutator.
   * \param op The loop node.
   * \param s The statement that holds op.
   * \return The unrolled statement (mutated again so that loops nested in the
   *  copies are also considered), Evaluate(0) for a zero-extent loop, or the
   *  result of the default mutation when the loop is left alone.
   */
  Stmt Mutate_(const For* op, const Stmt& s) {
    // constant folding.
    Expr extent = ir::Simplify(op->extent);
    // Keep the extent in 64 bits. Narrowing to int before the range check
    // would let large extents wrap around (e.g. 2^32 -> 0) and be mistaken
    // for a small or empty loop.
    int64_t value = -1;
    if (const IntImm* v1 = extent.as<IntImm>()) {
      value = v1->value;
    } else if (const UIntImm* v2 = extent.as<UIntImm>()) {
      const uint64_t int64_max =
          static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
      // Saturate instead of wrapping into the negative range.
      value = v2->value > int64_max ?
          std::numeric_limits<int64_t>::max() :
          static_cast<int64_t>(v2->value);
    }

    bool allow_unroll = value >= 0 && value <= max_auto_step_;
    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
      // The unroll counter below is an int; reject extents that do not fit.
      CHECK_LE(value, static_cast<int64_t>(std::numeric_limits<int>::max()))
          << "Cannot unroll loop, extent " << value << " is too large";
      allow_unroll = true;
    }

    if (!allow_unroll) {
      return IRMutator::Mutate_(op, s);
    }
    if (value == 0) return Evaluate::make(0);

    // Safe: value is within [1, INT_MAX] on both paths that reach here.
    const int num_steps = static_cast<int>(value);
    Var lv(op->loop_var.node_);
    Map<Var, Expr> vmap;
    Stmt unrolled;
    for (int i = 0; i < num_steps; ++i) {
      // Bind the loop variable to min + i for this copy of the body.
      vmap.Set(lv,
               schedule::ComputeExpr<Add>(
                   op->min, make_const(op->loop_var.type(), i)));
      Stmt step = Substitute(op->body, vmap);
      if (unrolled.defined()) {
        unrolled = Block::make(unrolled, step);
      } else {
        unrolled = step;
      }
    }
    // Substitution happens before mutation so that inner loops whose extent
    // depended on the loop variable can become constant and be unrolled too.
    return this->Mutate(unrolled);
  }

 private:
  /*! \brief Largest constant extent that is unrolled automatically. */
  int max_auto_step_;
};


/*!
 * \brief Apply LoopUnroller to stmt and pass the outcome through ConvertSSA.
 * \param stmt The statement to process.
 * \param max_auto_step Forwarded to the LoopUnroller constructor.
 * \return The statement returned by ConvertSSA.
 */
Stmt UnrollLoop(Stmt stmt, int max_auto_step) {
  LoopUnroller unroller(max_auto_step);
  return ConvertSSA(unroller.Mutate(stmt));
}

} // namespace ir
} // namespace tvm
10 changes: 9 additions & 1 deletion src/schedule/schedule_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch,
return nest;
}

/*!
 * \brief Adapter that lets an IterVar-keyed table be used with ir::Substitute.
 * \param s The statement to rewrite.
 * \param value_map Replacement expression for each IterVar.
 * \return The result of ir::Substitute on s with each IterVar's var as key.
 */
Stmt Substitute(Stmt s,
                const std::unordered_map<IterVar, Expr>& value_map) {
  Map<Var, Expr> var_map;
  for (auto it = value_map.begin(); it != value_map.end(); ++it) {
    var_map.Set(it->first->var, it->second);
  }
  return ir::Substitute(s, var_map);
}

Stmt MakeLoop(const Stage& s,
const Map<IterVar, Range>& dom_map,
Stmt provide,
Expand All @@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s,
auto nest = MakeLoopNest(s, dom_map, 0, false,
bound_state, {}, &value_map);


provide = Substitute(provide, value_map);
if (init.defined()) {
// try to find the location to insert the initialization.
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_pass_unroll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import tvm

# Smoke test for tvm.ir_pass.UnrollLoop: builds a two-level loop nest, runs
# the pass with max_auto_step=8 and prints the result (no assertions).
def test_unroll_loop():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
j = tvm.Var('j')
# Outer loop over i: the arguments after the loop var are presumably
# (min=n, extent=2, ...), i.e. a constant extent of 2 - TODO confirm.
# The inner loop over j is then (min=0, extent=n), a symbolic extent.
stmt = tvm.make.For(
i, n, 2, 0, 0,
tvm.make.For(j, 0, n, 0, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
stmt = tvm.ir_pass.UnrollLoop(stmt, 8)
print(stmt)

if __name__ == "__main__":
test_unroll_loop()

0 comments on commit 9bcaeb0

Please sign in to comment.