diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 8eaec0f523151..f8412dc3666b6 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt); * \param value_map The map of new values. * \return The converted form. */ -Stmt Substitute(Stmt stmt, const Map& value_map); +Stmt Substitute(Stmt stmt, const Map& value_map); /*! * \brief inline all calls of f in stmt. @@ -97,6 +97,13 @@ Stmt Inline(Stmt stmt, Stmt StorageFlatten(Stmt stmt, Map extern_buffer); +/*! + * \brief unroll the constant loops + * \param stmt The statment to be unrolled. + * \param max_auto_step The maximum step to stop performing automatic unrolling. + */ +Stmt UnrollLoop(Stmt stmt, int max_auto_step); + /*! * \brief Make an user callable API LoweredFunc. * @@ -153,6 +160,7 @@ Array SplitHostDevice(LoweredFunc func); */ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); + } // namespace ir } // namespace tvm diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index eafc367fe3c59..3b1921ee88685 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const { CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" - << "but request arg" << i; + << " but request arg[" << i << "]."; return TVMArgValue(values[i], type_codes[i]); } diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 6e7bbd8491713..df79996e4a6f7 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace tvm { @@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal) } }); +TVM_REGISTER_API(_pass_PostOrderVisit) +.set_body([](TVMArgs args, TVMRetValue *ret) { + PackedFunc f = args[1]; + ir::PostOrderVisit(args[0], [f](const NodeRef& n) { + f(n); + }); + }); + // make from two arguments #define REGISTER_PASS1(PassName) \ TVM_REGISTER_API(_pass_## PassName) \ @@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(VerifySSA); REGISTER_PASS4(Inline); REGISTER_PASS2(StorageFlatten); +REGISTER_PASS2(UnrollLoop); REGISTER_PASS2(StorageSync); REGISTER_PASS4(MakeAPI); REGISTER_PASS1(SplitHostDevice); diff --git a/src/pass/inline.cc b/src/pass/inline.cc index de452c364cd8a..1dee4776e6abb 100644 --- a/src/pass/inline.cc +++ b/src/pass/inline.cc @@ -24,10 +24,24 @@ class IRInline : public IRMutator { if (op->func == f_) { CHECK_EQ(op->value_index, 0); Expr expr = body_; - CHECK_EQ(args_.size(), op->args.size()) - << op->args.size() << " vs " << args_.size(); - for (size_t i = 0; i < args_.size(); ++i) { - expr = Let::make(args_[i], op->args[i], expr); + CHECK_EQ(args_.size(), op->args.size()); + + bool has_side_effect = false; + for (size_t i = 0; i < op->args.size(); ++i) { + if (HasSideEffect(op->args[i])) has_side_effect = true; + } + + if (has_side_effect) { + for (size_t i = 0; i < args_.size(); ++i) { + expr = Let::make(args_[i], op->args[i], expr); + } + } else { + Map vmap; + for (size_t i = 0; i < args_.size(); ++i) { + vmap.Set(args_[i], op->args[i]); + } + expr = Substitute( + Evaluate::make(expr), vmap).as()->value; } return expr; } else { diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 0fe6b94ebd24b..5fc928cdd32b3 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator { std::unordered_map smap; }; -Stmt Substitute(Stmt stmt, const Map& value_map) { +Stmt Substitute(Stmt stmt, const Map& value_map) { IRSubstitue m; for (auto kv : value_map) { - m.smap[kv.first->var.get()] = kv.second; + m.smap[kv.first.get()] = kv.second; } return m.Mutate(stmt); } diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc new file mode 100644 index 0000000000000..1374f55630fa6 --- /dev/null +++ b/src/pass/unroll_loop.cc @@ -0,0 +1,77 @@ +/*! + * Copyright (c) 2016 by Contributors + * SSA related checks and pass. + * \file ssa.cc + */ +#include +#include +#include +#include +#include +#include +#include "../schedule/compute_expr.h" + +namespace tvm { +namespace ir { + +class LoopUnroller : public IRMutator { + public: + explicit LoopUnroller(int max_auto_step) + : max_auto_step_(max_auto_step) { + } + + Stmt Mutate_(const For* op, const Stmt& s) { + Stmt stmt = s; + // constant folding. + Expr extent = ir::Simplify(op->extent); + const IntImm* v1 = extent.as(); + const UIntImm* v2 = extent.as(); + int value = -1; + if (v1 != nullptr) { + value = static_cast(v1->value); + } + if (v2 != nullptr) { + value = static_cast(v2->value); + } + bool allow_unroll = value >= 0 && value <= max_auto_step_; + if (op->for_type == ForType::Unrolled) { + CHECK_GE(value, 0) + << "Cannot unroll non-constant loop"; + allow_unroll = true; + } + + if (allow_unroll) { + if (value == 0) return Evaluate::make(0); + Stmt body = op->body; + Map vmap; + Stmt unrolled; + for (int i = 0; i < value; ++i) { + Var lv(op->loop_var.node_); + vmap.Set(lv, + schedule::ComputeExpr( + op->min, make_const(op->loop_var.type(), i))); + Stmt step = Substitute(body, vmap); + if (unrolled.defined()) { + unrolled = Block::make(unrolled, step); + } else { + unrolled = step; + } + } + return this->Mutate(unrolled); + } else { + return IRMutator::Mutate_(op, stmt); + } + } + + private: + int max_auto_step_; +}; + + +Stmt UnrollLoop(Stmt stmt, int max_auto_step) { + Stmt ret = LoopUnroller(max_auto_step).Mutate(stmt); + return ConvertSSA(ret); +} + +} // namespace ir +} // namespace tvm diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index e1390b5891f86..4e65f34f6bf72 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch, return nest; } +Stmt Substitute(Stmt s, + const std::unordered_map& value_map) { + Map temp; + for (const auto& kv : value_map) { + temp.Set(kv.first->var, kv.second); + } + return ir::Substitute(s, temp); +} + Stmt MakeLoop(const Stage& s, const Map& dom_map, Stmt provide, @@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s, auto nest = MakeLoopNest(s, dom_map, 0, false, bound_state, {}, &value_map); - provide = Substitute(provide, value_map); if (init.defined()) { // try to find the location to insert the initialization. diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py new file mode 100644 index 0000000000000..191377baaab60 --- /dev/null +++ b/tests/python/unittest/test_pass_unroll.py @@ -0,0 +1,20 @@ +import tvm + +def test_unroll_loop(): + dtype = 'int64' + n = tvm.Var('n') + Ab = tvm.Buffer((n, ), dtype) + i = tvm.Var('i') + j = tvm.Var('j') + # for i in 0 to n-1: + stmt = tvm.make.For( + i, n, 2, 0, 0, + tvm.make.For(j, 0, n, 0, 0, + tvm.make.Store(Ab.data, + tvm.make.Load(dtype, Ab.data, i) + 1, + j + 1))) + stmt = tvm.ir_pass.UnrollLoop(stmt, 8) + print(stmt) + +if __name__ == "__main__": + test_unroll_loop()