From bead881b447feed7144763e801a8ade78e8875d2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 31 Jan 2017 14:46:38 -0800 Subject: [PATCH] [SCHEDULE] Improve bound inference, support reduce codegen. --- include/tvm/expr.h | 21 +- include/tvm/ir.h | 6 +- include/tvm/ir_pass.h | 21 +- include/tvm/operation.h | 3 + include/tvm/schedule.h | 37 ++ include/tvm/schedule_pass.h | 9 + python/tvm/api.py | 36 +- python/tvm/build.py | 6 +- python/tvm/schedule.py | 8 + src/api/api_lang.cc | 6 + src/api/api_pass.cc | 1 - src/api/api_schedule.cc | 1 + src/codegen/codegen_c.cc | 5 +- src/lang/ir.cc | 10 +- src/lang/operation.cc | 11 +- src/pass/ir_mutator.cc | 8 +- src/pass/ir_visitor.cc | 2 +- src/pass/schedule_ops.cc | 334 ---------- src/pass/simple_passes.cc | 22 + src/schedule/bound.cc | 64 +- src/schedule/compute_expr.h | 109 ++++ src/schedule/int_set.cc | 593 +++++++++++------- src/schedule/int_set.h | 81 ++- src/schedule/schedule_lang.cc | 47 +- src/schedule/schedule_ops.cc | 388 ++++++++++++ tests/python/integration/test_ewise.py | 3 +- tests/python/integration/test_reduce.py | 45 ++ tests/python/unittest/test_codegen_device.py | 3 +- tests/python/unittest/test_codegen_makeapi.py | 3 +- tests/python/unittest/test_lang_tensor.py | 2 +- .../unittest/test_pass_storage_flatten.py | 2 +- ...e_ops.py => test_schedule_schedule_ops.py} | 6 +- 32 files changed, 1247 insertions(+), 646 deletions(-) delete mode 100644 src/pass/schedule_ops.cc create mode 100644 src/schedule/compute_expr.h create mode 100644 src/schedule/schedule_ops.cc create mode 100644 tests/python/integration/test_reduce.py rename tests/python/unittest/{test_pass_schedule_ops.py => test_schedule_schedule_ops.py} (89%) diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 067c8dff3b14..2c01d7acadbf 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter; using Halide::Internal::Variable; using Halide::Internal::make_const; +using Halide::Internal::make_zero; +using Halide::Internal::as_const_int; +using Halide::Internal::as_const_uint; inline Type TVMType2Type(TVMType t) { @@ -126,25 +129,25 @@ using Halide::abs; using Halide::select; /*! - * \brief sum of of source expression over rdom + * \brief sum of of source expression over axis * \param source The source expression. - * \param rdom List of iteration variables that will be used for reduction. + * \param axis List of iteration variables that will be used for reduction. */ -Expr sum(Expr source, Array rdom); +Expr sum(Expr source, Array axis); /*! - * \brief max of of source expression over rdom + * \brief max of of source expression over axis * \param source The source expression. - * \param rdom List of iteration variables that will be used for reduction. + * \param axis List of iteration variables that will be used for reduction. */ -Expr max(Expr source, Array rdom); +Expr max(Expr source, Array axis); /*! - * \brief max of of source expression over rdom + * \brief max of of source expression over axis * \param source The source expression. - * \param rdom List of iteration variables that will be used for reduction. + * \param axis List of iteration variables that will be used for reduction. */ -Expr min(Expr source, Array rdom); +Expr min(Expr source, Array axis); // print functions for expr diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 6f61aced7aea..8de8615b0e06 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -30,8 +30,8 @@ struct Reduce : public ExprNode { std::string op; /*! \brief The source operand */ Expr source; - /*! \brief The reduction domains */ - Array rdom; + /*! \brief The reduction axis */ + Array axis; /*! \brief construct expr from op and rdom */ static Expr make(std::string op, Expr src, Array rdom); @@ -40,7 +40,7 @@ struct Reduce : public ExprNode { v->Visit("dtype", &type); v->Visit("op", &op); v->Visit("source", &source); - v->Visit("rdom", &rdom); + v->Visit("axis", &axis); } static const IRNodeType _type_info = IRNodeType::ExtensionExpr; static constexpr const char* _type_key = "Reduce"; diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index a2c2956a944a..fc7eab94a4cf 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -3,8 +3,8 @@ * \file ir_pass.h * \brief Collection of IR pass functions * - * All the pass functions in this file are for Stmt, - * We can use PassFunction(Evaluate(expr)) to apply it to Expr + * When the pass functions in this file are for Stmt, + * we can use PassFunction(Evaluate(expr)) to apply it to Expr */ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ @@ -37,15 +37,6 @@ inline Stmt Simplify(Stmt a) { return Halide::Internal::simplify(a); } -/*! - * \brief Schedule s' dependent operations. - * - * \param s The schedule to be realized - * \param dom_map The domain of each iter vars. - * \return the result Stmt - */ -Stmt ScheduleOps(Schedule s, Map dom_map); - /*! * \brief verifies whether the IR stmt or Expr is in SSA form. * That is: each VarExpr is defined and assigned once(in Let/For) @@ -69,6 +60,14 @@ bool HasSideEffect(const Expr& e); */ Stmt ConvertSSA(Stmt stmt); +/*! + * \brief Substitute the var specified in key->var to be value. + * \param stmt The source statement to be substituted + * \param value_map The map of new values. + * \return The converted form. + */ +Stmt Substitute(Stmt stmt, const Map& value_map); + /*! * \brief inline all calls of f in stmt. * diff --git a/include/tvm/operation.h b/include/tvm/operation.h index aff7d9b2d637..a48d0e5b8e6e 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode { public: /*! \brief IterVar on each axis */ Array axis; + /*! \brief IterVar on each reduction axis, if the body is a Reduce */ + Array reduce_axis; /*! \brief the compute expression */ Expr body; /*! \brief constructor */ @@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode { void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); v->Visit("axis", &axis); + v->Visit("reduce_axis", &reduce_axis); v->Visit("body", &body); } static Operation make(std::string name, diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index cbb3cc81c0d3..f115dbc6f18f 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -123,6 +123,8 @@ class Stage : public NodeRef { IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, Expr x_factor, Expr y_factor); + // declare container type + using ContainerType = StageNode; }; /*! @@ -152,11 +154,22 @@ class Schedule : public NodeRef { Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); } + /*! + * \brief Normalize the schedule. + * This is needed before bound inference. + * Insert necessary RebaseNode to make sure all leaf_iter_vars + * are in form [0, extent) + * + * \return A normalized schedule, can be same as current one. + */ + void normalize(); /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const ScheduleNode* operator->() const; + // declare container type + using ContainerType = ScheduleNode; }; /*! @@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode { TVM_DECLARE_NODE_TYPE_INFO(FuseNode); }; +/*! + * \brief Rebase the iteration to make min to be 0. + * This is useful to normalize the Schedule + * to make every leaf variable's min to be 0. + */ +class RebaseNode : public IterVarRelationNode { + public: + /*! \brief The parent domain */ + IterVar parent; + /*! \brief The inner domain */ + IterVar rebased; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("parent", &parent); + v->Visit("rebased", &rebased); + } + + static IterVarRelation make(IterVar parent, IterVar rebased); + + static constexpr const char* _type_key = "Rebase"; + TVM_DECLARE_NODE_TYPE_INFO(RebaseNode); +}; + + // implementations inline const StageNode* Stage::operator->() const { return static_cast(node_.get()); diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h index 45b2745c9eab..57e442c5c15e 100644 --- a/include/tvm/schedule_pass.h +++ b/include/tvm/schedule_pass.h @@ -24,6 +24,15 @@ namespace schedule { */ Map InferBound(Schedule sch); +/*! + * \brief Schedule s' dependent operations. + * + * \param s The schedule to be realized + * \param dom_map The domain of each iter vars. + * \return the result Stmt + */ +Stmt ScheduleOps(Schedule s, Map dom_map); + } // namespace schedule } // namespace tvm #endif // TVM_SCHEDULE_PASS_H_ diff --git a/python/tvm/api.py b/python/tvm/api.py index 85009186646a..bb1a563b23fa 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''): return _api_internal._IterVar(dom, name, thread_tag) -def sum(expr, rdom): - """Create a sum expression over rdom +def sum(expr, axis): + """Create a sum expression over axis Parameters ---------- expr : Expr The source expression. - rdom : RDomain - The reduction domainx + axis : IterVar + The reduction IterVar axis """ - rdom = rdom if isinstance(rdom, list) else [rdom] - x = _make.Reduce("Add", expr, rdom) + axis = axis if isinstance(axis, list) else [axis] + x = _make.Reduce("Add", expr, axis) return x -def min(expr, rdom): - """Create a min expression over rdom +def min(expr, axis): + """Create a min expression over axis Parameters ---------- expr : Expr The source expression. - rdom : RDomain - The reduction domainx + axis : IterVar + The reduction IterVar axis """ - rdom = rdom if isinstance(rdom, list) else [rdom] - x = _make.Reduce("Min", expr, rdom) + axis = axis if isinstance(axis, list) else [axis] + x = _make.Reduce("Min", expr, axis) return x -def max(expr, rdom): - """Create a min expression over rdom +def max(expr, axis): + """Create a min expression over axis Parameters ---------- expr : Expr The source expression. - rdom : RDomain - The reduction domainx + axis : IterVar + The reduction IterVar axis """ - rdom = rdom if isinstance(rdom, list) else [rdom] - x = _make.Reduce("Max", expr, rdom) + axis = axis if isinstance(axis, list) else [axis] + x = _make.Reduce("Max", expr, axis) return x diff --git a/python/tvm/build.py b/python/tvm/build.py index 407a1ba146aa..8839031311e9 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -62,9 +62,10 @@ def build(sch, # lowering bounds = schedule.InferBound(sch) - stmt = ir_pass.ScheduleOps(sch, bounds) + stmt = schedule.ScheduleOps(sch, bounds) stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.Simplify(stmt) + print(stmt) fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list)) fsplits = codegen.SplitHostDevice(fapi) @@ -73,7 +74,8 @@ def build(sch, for i, f in enumerate(fsplits): t = target if i >= 1 else "c" record_codes.append(codegen.CompileToC(f, output_ssa, t)) - + for c in record_codes: + print(c) if target == "cuda": ret = codegen.BuildNVRTC(fsplits, "stackvm") elif target == "opencl": diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 41a6afded977..3fd7f9730d46 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -33,6 +33,14 @@ def __getitem__(self, k): raise ValueError("Cannot find the operation %s in schedule" % (str(k))) return self.stage_map[k] + def normalize(self): + """Build a normalized schedule. + + Insert necessary rebase to make certain iter var to start from 0. + This is needed before bound inference and followup step. + """ + _api_internal._ScheduleNormalize(self) + @register_node class Stage(NodeBase): """A Stage represents schedule for one operation.""" diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index e570e55ce999..3393228f8104 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile) *ret = Array({x_outer, y_outer, x_inner, y_inner}); }); +TVM_REGISTER_API(_ScheduleNormalize) +.set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Schedule() + .normalize(); + }); + } // namespace tvm diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 0a88e3b130e4..f549b6b2ee25 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal) REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(VerifySSA); REGISTER_PASS4(Inline); -REGISTER_PASS2(ScheduleOps); REGISTER_PASS2(StorageFlatten); } // namespace ir diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index a84642f99efa..a4462117d494 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -29,6 +29,7 @@ namespace schedule { REGISTER_SCHEDULE_PASS1(InferBound); REGISTER_SCHEDULE_PASS1(CreateReadGraph); REGISTER_SCHEDULE_PASS2(PostDFSOrder); +REGISTER_SCHEDULE_PASS2(ScheduleOps); } // namespace schedule } // namespace tvm diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index eade8577f2ea..737cdc18bd7a 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -2,6 +2,7 @@ * Copyright (c) 2017 by Contributors * \file codegen_c.cc */ +#include #include "./codegen_c.h" namespace tvm { @@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N switch (op->type.bits()) { case 64: case 32: { std::ostringstream temp; - temp << op->value; + temp << std::scientific << op->value; if (op->type.bits() == 32) temp << 'f'; p->MarkConst(temp.str()); os << temp.str(); @@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N case 16: { os << '('; p->PrintType(op->type, os); - os << ')' << op->value << 'f'; + os << ')' << std::scientific <value << 'f'; break; } default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n"; diff --git a/src/lang/ir.cc b/src/lang/ir.cc index fbdb34ac6680..9a638b44f1ac 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) << op->op << ", "; p->print(op->source); - p->stream << ", rdom=" << op->rdom << ")"; + p->stream << ", axis=" << op->axis << ")"; }); } // namespace Internal @@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) namespace tvm { namespace ir { -Expr Reduce::make(std::string op, Expr source, Array rdom) { +Expr Reduce::make(std::string op, Expr source, Array axis) { auto n = std::make_shared(); CHECK(source.defined()); - for (size_t i = 0; i < rdom.size(); ++i) { - CHECK(rdom[i].defined()); + for (size_t i = 0; i < axis.size(); ++i) { + CHECK(axis[i].defined()); } n->type = source.type(); n->source = source; n->op = op; - n->rdom = rdom; + n->axis = axis; return Expr(n); } diff --git a/src/lang/operation.cc b/src/lang/operation.cc index ce26e65da8fe..95c292e48dd2 100644 --- a/src/lang/operation.cc +++ b/src/lang/operation.cc @@ -4,6 +4,7 @@ */ #include #include +#include #include namespace tvm { @@ -57,7 +58,12 @@ Tensor Placeholder(Array shape, Type dtype, std::string name) { // ComputeOpNode Array ComputeOpNode::root_iter_vars() const { - return axis; + if (reduce_axis.size() == 0) return axis; + Array ret = axis; + for (IterVar iv : reduce_axis) { + ret.push_back(iv); + } + return ret; } Type ComputeOpNode::output_dtype(size_t i) const { @@ -101,6 +107,9 @@ Operation ComputeOpNode::make(std::string name, n->name = name; n->axis = axis; n->body = body; + if (n->body->is_type()) { + n->reduce_axis = n->body.as()->axis; + } return Operation(n); } diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index 85b0589ce60c..72f118a4667f 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -37,7 +37,7 @@ inline Array MutateArray(Array arr, IRMutator *m) { } } -inline Array MutateRDom(Array rdom, IRMutator *m) { +inline Array MutateIterVarArr(Array rdom, IRMutator *m) { std::vector new_dom(rdom.size()); bool changed = false; for (size_t i = 0; i < rdom.size(); i++) { @@ -237,13 +237,13 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) { TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .set_dispatch([](const Reduce* op, const Expr& e, IRMutator* m) { - Array new_rdom = MutateRDom(op->rdom, m); + Array new_axis = MutateIterVarArr(op->axis, m); Expr new_source = m->Mutate(op->source); - if (op->rdom.same_as(new_rdom) && + if (op->axis.same_as(new_axis) && op->source.same_as(new_source)) { return e; } else { - return Reduce::make(op->op, new_source, new_rdom); + return Reduce::make(op->op, new_source, new_axis); } }); diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 77ce3928f2fe..4b8b005ddea5 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -120,7 +120,7 @@ void IRVisitor::Visit_(const Call *op) { TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .set_dispatch([](const Reduce* op, IRVisitor* v) { - VisitRDom(op->rdom, v); + VisitRDom(op->axis, v); v->Visit(op->source); }) .set_dispatch(NoOp) diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc deleted file mode 100644 index c2332a819609..000000000000 --- a/src/pass/schedule_ops.cc +++ /dev/null @@ -1,334 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file schedule_ops.cc - */ -#include -#include -#include -#include -#include - -#include "./scope.h" -#include "./ir_util.h" -#include "../schedule/graph.h" - -namespace tvm { -namespace ir { - -/*! - * \brief use message passing to calculate the assignment of each Var inside the loop body. - * \param s The schedule to be used. - * \param dom_map The domain map of each iteration variable's domain - * \param p_state The message passing state - * IterVar->The assignment. - */ -void PassUpOffset(const Stage& s, - const Map& dom_map, - std::unordered_map* p_state) { - auto& state = *p_state; - for (size_t i = s->relations.size(); i != 0; --i) { - IterVarRelation rel = s->relations[i - 1]; - if (rel.as()) { - const SplitNode* s = rel.as(); - Expr outer = state.at(s->outer); - Expr inner = state.at(s->inner); - Expr factor = dom_map.at(s->inner)->extent; - Expr parent_min = dom_map.at(s->parent)->min; - state[s->parent] = inner + outer * factor; - // add min if they exist - if (!is_zero(parent_min)) { - state[s->parent] = parent_min + state[s->parent]; - } - } else if (rel.as()) { - const FuseNode* s = rel.as(); - Expr value = state.at(s->fused); - Expr factor = dom_map.at(s->inner)->extent; - Expr outer_min = dom_map.at(s->outer)->min; - Expr inner_min = dom_map.at(s->inner)->min; - state[s->outer] = value / factor; - state[s->inner] = value % factor; - // add min if they exist - if (!is_zero(outer_min)) { - state[s->outer] = outer_min + state[s->outer]; - } - if (!is_zero(inner_min)) { - state[s->inner] = outer_min + state[s->inner]; - } - } else { - LOG(FATAL) << "unknown relation type"; - } - } -} - -/*! - * \brief split the expr by addition. - * \param expr The expression to be splitted. - * \param loop_level The loop level of each Variable - * \param result vector of (level, expr) - * The level gives the mimimum loop level this expression need to be computed. - * The Expr gives the expression content. - */ -void SplitByAdd(Expr expr, - const std::unordered_map& loop_level, - std::vector > *result) { - const Add* op = expr.as(); - if (op != nullptr) { - SplitByAdd(op->a, loop_level, result); - SplitByAdd(op->b, loop_level, result); - } else { - size_t max_level = 0; - auto fvisit = [&max_level, &loop_level](const NodeRef& n) { - const Variable* op = n.as(); - if (op != nullptr) { - auto it = loop_level.find(op); - if (it != loop_level.end()) { - max_level = std::max(max_level, it->second); - } - } - }; - PostOrderVisit(expr, fvisit); - result->push_back(std::make_pair(max_level, expr)); - } -} - -/*! - * \brief Make the loop nest of the correspondings schedule. - * \param sch The schedule. - * \param dom_map The domain map. - * - * \return a nested representation of loop statements. - * The flattened Stmt are ordered from outmost to inner most order. - */ -std::vector > MakeLoopNest( - const Stage& sch, - const Map& dom_map) { - // optional, use let to define some CSE in dom_map. - auto leaf_iter_vars = sch->leaf_iter_vars; - std::unordered_map offset; - std::unordered_map loop_level; - Stmt no_op = Evaluate::make(0); - // create the loop nest - std::vector > nest; - nest.resize(leaf_iter_vars.size() + 1); - - for (size_t i = 0; i < leaf_iter_vars.size(); ++i) { - auto iv = leaf_iter_vars[i]; - Range dom = dom_map.at(iv); - // initialize the offset and loop_level - offset[iv] = iv->var; - loop_level[iv->var.as()] = i + 1; - // Mark the iter var in the IR, to remember the point - if (iv->thread_tag.length() == 0) { - if (is_zero(dom->min)) { - nest[i + 1].emplace_back( - For::make(iv->var, 0, dom->extent, - ForType::Serial, DeviceAPI::None, no_op)); - } else { - Var idx(iv->var->name_hint + ".idx", iv->var.type()); - nest[i + 1].emplace_back( - For::make(idx, 0, dom->extent, - ForType::Serial, DeviceAPI::None, no_op)); - nest[i + 1].emplace_back( - LetStmt::make(iv->var, dom->min + idx, no_op)); - } - } else { - // Always restrict threaded IterVar to starts from 0. - CHECK(is_zero(dom->min)); - // annotate the extent of the IterVar - nest[i + 1].emplace_back( - AttrStmt::make(iv, "thread_extent", dom->extent, no_op)); - } - // annotate the extent of the IterVar - nest[i + 1].emplace_back( - AttrStmt::make(iv, "scope", iv->var, no_op)); - } - // message passing to get offset of root iter vars. - PassUpOffset(sch, dom_map, &offset); - - for (IterVar iv : sch->op->root_iter_vars()) { - Expr value = offset.at(iv); - if (!value.same_as(iv->var)) { - using Entry = std::pair; - std::vector splits; - SplitByAdd(value, loop_level, &splits); - - Expr offset = 0; - size_t nsplit_left = splits.size() - 1; - for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) { - size_t hit = 0; - for (const auto& kv : splits) { - if (kv.first == i) { - if (is_zero(offset)) { - offset = kv.second; - } else { - offset = offset + kv.second; - ++hit; - } - } - } - nsplit_left -= hit; - if (hit != 0) { - std::ostringstream os; - os << iv->var->name_hint << ".at.l" << i; - Var base_offset(os.str()); - if (nsplit_left == 0) { - base_offset = iv->var; - } - nest[i].emplace_back( - LetStmt::make(base_offset, offset, no_op)); - offset = base_offset; - } - } - Range dom = dom_map.at(iv); - if (!offset.same_as(iv->var)) { - // define the iv->var - nest.back().emplace_back( - LetStmt::make(iv->var, offset, no_op)); - } - Expr condition = (iv->var - dom->min) < dom->extent; - // Boundary condition checking - // Need better boundary condition here. - nest.back().emplace_back(IfThenElse::make(condition, no_op)); - } - } - return nest; -} - - -/*! - * \brief Make pipeline specifically for compute op node. - * \param op The compute node - * \param tensors The tensors generated by provide. - */ -Stmt MakeProvide(const ComputeOpNode* op, - const std::vector& tensors) { - Tensor t = tensors[0]; - Array args; - for (IterVar iv : op->axis) { - args.push_back(iv->var); - } - return Provide::make(t->op, t->value_index, op->body, args); -} - -/*! - * \brief Make pipeline specifically for compute op node. - * \param op The compute node - * \param dom_map The domain map - * \param tensors The tensors generated by provide. - * \param body The content of the pipeline. - */ -Stmt MakeRealize(const ComputeOpNode* op, - const Map& dom_map, - const std::vector& tensors, - Stmt body) { - Tensor t = tensors[0]; - Halide::Internal::Region bounds; - for (IterVar iv : op->axis) { - bounds.push_back(dom_map.at(iv)); - } - return Realize::make(t->op, t->value_index, t->dtype, - bounds, make_const(Bool(1), true), body); -} - -Stmt MakePipeline(const Stage& sch, - const Map& dom_map, - Stmt consumer) { - std::vector tensors; - for (int i = 0; i < sch->op->num_outputs(); ++i) { - tensors.emplace_back(sch->op.output(i)); - } - - Stmt provide; - if (sch->op.as()) { - provide = MakeProvide(sch->op.as(), tensors); - } else { - LOG(FATAL) << "not supported op " << sch->op->type_key(); - } - std::vector > nest = MakeLoopNest(sch, dom_map); - Stmt producer = MergeNest(nest, provide); - producer = ProducerConsumer::make(sch->op, true, producer); - - Stmt pipeline = producer; - if (consumer.defined()) { - consumer = ProducerConsumer::make(sch->op, false, consumer); - pipeline = Block::make(producer, consumer); - } - - if (sch->op.as()) { - return MakeRealize(sch->op.as(), - dom_map, tensors, pipeline); - } else { - LOG(FATAL) << "not supported op"; - return Stmt(); - } -} - -// inject the operator's realization on the stmt. -class InjectRealize : public IRMutator { - public: - InjectRealize(Stage schedule, Map dom_map) - : schedule(schedule), dom_map(dom_map) {} - - Stmt Mutate(Stmt stmt) final { - CHECK(stmt.defined()); - stmt = IRMutator::Mutate(stmt); - const AttrStmt* op = stmt.as(); - if (op != nullptr && - op->type_key == "scope") { - if (op->node == schedule->attach_ivar) { - CHECK(!found_attach); - found_attach = true; - stmt = AttrStmt::make( - op->node, op->type_key, op->value, - MakePipeline(schedule, dom_map, - IRMutator::Mutate(op->body))); - } - } - return stmt; - } - // the operations to be carried - Stage schedule; - // domain map - Map dom_map; - // whether attach point is found - bool found_attach{false}; -}; - -Stmt InjectInline(const Operation op, Stmt body) { - CHECK(body.defined()); - const ComputeOpNode* compute = op.as(); - CHECK(compute != nullptr) - << "can only inline compute op"; - Array args; - for (auto iv : compute->axis) { - args.push_back(iv->var); - } - return Inline(body, op, args, compute->body); -} - - -Stmt ScheduleOps( - Schedule sch, Map dom_map) { - Stmt body = Stmt(); - // reverse the post DFS order. - for (size_t i = sch->stages.size(); i != 0; --i) { - Stage s = sch->stages[i - 1]; - // no need to specify place holder op. - if (s->op.as()) continue; - if (s->attach_type == kInline) { - body = InjectInline(s->op, body); - } else if (s->attach_type == kRoot || s-> attach_type == kNone) { - body = MakePipeline(s, dom_map, body); - } else if (s->attach_type == kScope) { - CHECK(body.defined()); - InjectRealize mutator(s, dom_map); - body = mutator.Mutate(body); - CHECK(mutator.found_attach) - << "did not find attachment point"; - } - } - return body; -} - -} // namespace ir -} // namespace tvm diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 38939459722b..0fe6b94ebd24 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include namespace tvm { @@ -32,5 +33,26 @@ bool HasSideEffect(const Expr& e) { v.Visit(e); return v.has_side_effect_; } + +class IRSubstitue : public IRMutator { + public: + Expr Mutate_(const Variable* op, const Expr& e) final { + auto it = smap.find(op); + if (it != smap.end()) { + return it->second; + } else { + return e; + } + } + std::unordered_map smap; +}; + +Stmt Substitute(Stmt stmt, const Map& value_map) { + IRSubstitue m; + for (auto kv : value_map) { + m.smap[kv.first->var.get()] = kv.second; + } + return m.Mutate(stmt); +} } // namespace ir } // namespace tvm diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index d4ce520c9229..36532aa419d7 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -54,6 +54,11 @@ void PassDown(const Stage& s, const Range& range_inner = state.at(r->inner); state[r->fused] = Range::make_with_min_extent( 0, range_outer->extent * range_inner->extent); + } else if (rel.as()) { + const RebaseNode* r = rel.as(); + CHECK(state.count(r->parent)); + state[r->rebased] = Range::make_with_min_extent( + 0, state.at(r->parent)->extent); } else { LOG(FATAL) << "unknown relation type"; } @@ -85,6 +90,13 @@ void PassUp(const Stage& s, &outer, &inner); state[r->outer] = outer; state[r->inner] = inner; + } else if (rel.as()) { + IntSet parent; + const RebaseNode* r = rel.as(); + PassUp(r, dom_map, + state.at(r->rebased), + &parent); + state[r->parent] = parent; } else { LOG(FATAL) << "unknown relation type"; } @@ -109,9 +121,15 @@ void PassToOperation( // Eventually, we need to change the inference to be a Pull style inference if (tensor->op.as()) { auto root_iter_vars = tensor->op->root_iter_vars(); - CHECK_EQ(tensor.ndim(), root_iter_vars.size()); - for (size_t i = 0; i < tensor.ndim(); ++i) { - (*result)[root_iter_vars[i]].push_back(dim_bounds[i]); + const ComputeOpNode* op = tensor->op.as(); + CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size()); + for (size_t i = 0; i < op->axis.size(); ++i) { + (*result)[op->axis[i]].push_back(dim_bounds[i]); + } + // reduction. + for (size_t i = 0; i < op->reduce_axis.size(); ++i) { + (*result)[op->reduce_axis[i]].push_back( + IntSet::range(op->reduce_axis[i]->dom)); } } else { LOG(FATAL) << "unknown operation mode " << tensor->op->type_key(); @@ -173,9 +191,9 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) { {"local", 2} }; static std::unordered_map thread_tag_rank{ - {"gridIdx.x", 0}, - {"gridIdx.y", 0}, - {"gridIdx.z", 0}, + {"blockIdx.x", 0}, + {"blockIdx.y", 0}, + {"blockIdx.z", 0}, {"threadIdx.x", 1}, {"threadIdx.y", 1}, {"threadIdx.z", 1} @@ -194,8 +212,6 @@ void InferBound(const Stage& stage, (*rmap)[iv] = iv->dom; } } - // get range of all child iter vars. - PassDown(stage, rmap); if (stage->attach_type == kScope) { Stage parent = stage->attach_stage; @@ -206,10 +222,18 @@ void InferBound(const Stage& stage, bool fix_value = true; for (auto iv : parent->leaf_iter_vars) { + Range vrange = rmap->at(iv); + CHECK(is_zero(vrange->min)) + << "InferBound requires every leaf iter var's min equals 0, " + << "call schedule.normalize to achieve this."; + // special optimization to remove trivial loop + if (is_one(vrange->extent)) { + up_state[iv] = IntSet::single_point(vrange->min); + } if (fix_value && !ScopeRelax(iv, stage->scope)) { - up_state[iv] = IntSet::make_point(iv->var); + up_state[iv] = IntSet::single_point(iv->var); } else { - up_state[iv] = IntSet::make_range(rmap->at(iv)); + up_state[iv] = IntSet::range(vrange); } if (stage->attach_ivar == iv) { fix_value = false; @@ -223,12 +247,30 @@ void InferBound(const Stage& stage, bp_state[iv] = {up_state.at(iv)}; } auto result = BoundProp(post_order, &bp_state); + + // Set relaxation + Map relax_set; + Stage s = stage; + while (s->attach_type == kScope) { + s = s->attach_stage; + for (auto iv : s->leaf_iter_vars) { + if (ScopeRelax(iv, stage->scope)) { + relax_set.Set(iv, IntSet::range(rmap->at(iv))); + } + } + } for (auto iv : stage->op->root_iter_vars()) { CHECK(result.count(iv)); CHECK(!rmap->count(iv)); - (*rmap)[iv] = result.at(iv).GetCoverRange(); + Range r = result.at(iv).cover_range(iv->dom); + if (relax_set.size() != 0) { + r = EvalSet(r, relax_set).cover_range(iv->dom); + } + (*rmap)[iv] = r; } } + // get range of all child iter vars. + PassDown(stage, rmap); } diff --git a/src/schedule/compute_expr.h b/src/schedule/compute_expr.h new file mode 100644 index 000000000000..0feb582fcec2 --- /dev/null +++ b/src/schedule/compute_expr.h @@ -0,0 +1,109 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file compute_expr.h + * \brief Utility integer expression with quick eager simplification. + * This is weaker than Simplify but can be done Eagerly. + */ +#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_ +#define TVM_SCHEDULE_COMPUTE_EXPR_H_ + +#include +#include + +namespace tvm { +namespace schedule { + +using Halide::Internal::add_would_overflow; +using Halide::Internal::sub_would_overflow; +using Halide::Internal::mul_would_overflow; + +/*! + * \brief Compute the expression with the given binary op. + * \param lhs The left operand + * \param rhs The right operand + * \return The result. + */ +template +inline Expr ComputeExpr(Expr lhs, Expr rhs) { + return OP::make(lhs, rhs); +} + +template +inline bool GetConst(Expr e, T* out); + +template<> +bool GetConst(Expr e, int64_t *out) { + if (e.type().is_vector()) return false; + const int64_t *v = as_const_int(e); + if (v) { + *out = *v; return true; + } else { + return false; + } +} +template<> +bool GetConst(Expr e, uint64_t *out) { + if (e.type().is_vector()) return false; + const uint64_t *v = as_const_uint(e); + if (v) { + *out = *v; return true; + } else { + return false; + } +} + +#define TVM_CONST_PROPAGATION(OP_NAME, OP) \ + int64_t ia = 0, ib = 0; \ + if (GetConst(a, &ia) && GetConst(b, &ib)) { \ + if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \ + LOG(FATAL) << "signed int overflow"; \ + } \ + return ir::IntImm::make(a.type(), ia OP ib); \ + } \ + uint64_t ua = 0, ub = 0; \ + if (GetConst(a, &ua) && GetConst(b, &ub)) { \ + return ir::UIntImm::make(a.type(), ua + ub); \ + } \ + +template<> +inline Expr ComputeExpr(Expr a, Expr b) { + if (is_zero(a)) return b; + if (is_zero(b)) return a; + TVM_CONST_PROPAGATION(add, +); + return ir::Add::make(a, b); +} + +template<> +inline Expr ComputeExpr(Expr a, Expr b) { + if (is_zero(b)) return a; + TVM_CONST_PROPAGATION(sub, -); + return ir::Add::make(a, b); +} + +template<> +inline Expr ComputeExpr(Expr a, Expr b) { + if (is_one(a)) return b; + if (is_one(b)) return a; + TVM_CONST_PROPAGATION(mul, *); + return ir::Mul::make(a, b); +} + +template<> +inline Expr ComputeExpr(Expr a, Expr b) { + if (is_one(b)) return a; + return ir::Mul::make(a, b); +} + +template<> +inline Expr ComputeExpr(Expr a, Expr b) { + return Halide::Internal::Interval::make_max(a, b); +} + +template<> +inline Expr ComputeExpr(Expr a, Expr b) { + return Halide::Internal::Interval::make_min(a, b); +} + +} // namespace schedule +} // namespace tvm +#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_ diff --git a/src/schedule/int_set.cc b/src/schedule/int_set.cc index ac0b0c6ac910..0da1a39e7c60 100644 --- a/src/schedule/int_set.cc +++ b/src/schedule/int_set.cc @@ -1,212 +1,355 @@ /*! * Copyright (c) 2016 by Contributors - * \file int_set.cc + * \file int_set_impl.cc * \brief The integer set functions */ #include +#include +#include #include "./int_set.h" +#include "./compute_expr.h" namespace tvm { namespace schedule { +using Halide::Internal::Interval; + using namespace ir; +/*! \brief Set of continuous interval */ +struct IntervalSet : public IntSetNode { + /*! \brief the internal interval*/ + Interval i; + + static IntSet make(Interval i) { + std::shared_ptr n = + std::make_shared(); + n->i = i; + return IntSet(n); + } + static IntSet make(Expr min, Expr max) { + std::shared_ptr n = + std::make_shared(); + n->i.min = min; + n->i.max = max; + return IntSet(n); + } + + static constexpr const char* _type_key = "IntervalSet"; + TVM_DECLARE_NODE_TYPE_INFO(IntervalSet); +}; + /*! - * \brief Internal node container of int set. + * \brief set represented by strided integers + * Reserved for cases where strided access is supported. */ -class IntSetNode : public Node { - public: - /*! \brief The base range scope */ - Range base; - /*! \brief additional strided domain */ - Array domain; - /*! \brief The stride of each strided domain */ - Array stride; - /*! - * \brief The concrete set, - * used when concrete execution is enabled. - */ - std::vector concrete; - - void VisitAttrs(AttrVisitor* v) final { - v->Visit("base", &base); - v->Visit("domain", &domain); - v->Visit("stride", &stride); - } - - static constexpr const char* _type_key = "IntSet"; - TVM_DECLARE_NODE_TYPE_INFO(IntSetNode); +struct StrideSet : public IntSetNode { + /*! \brief the base inetrval */ + Interval base; + /*! \brief additional extents in positive number */ + Array extents; + /*! \brief additional strides in positive number */ + Array strides; + + static constexpr const char* _type_key = "StrideSet"; + TVM_DECLARE_NODE_TYPE_INFO(StrideSet); }; -TVM_REGISTER_NODE_TYPE(IntSetNode); +inline IntSet IntSet::cover_interval() const { + if ((*this).as()) return *this; + const StrideSet* s = (*this).as(); + if (s) { + CHECK_NE(s->extents.size(), 0U); + Expr max = s->base.max; + for (size_t i = 0; i < s->extents.size(); ++i) { + max = max + s->extents[i] * s->strides[i] - s->strides[i]; + } + return IntervalSet::make(s->base.min, max); + } + LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval"; + return IntSet::everything(); +} + +Range IntSet::cover_range(Range max_range) const { + IntSet temp; + const IntervalSet* s_int = (*this).as(); + if (s_int == nullptr) { + temp = this->cover_interval(); + s_int = temp.as(); + } + if (s_int->i.is_bounded()) { + return Range::make_with_min_extent( + s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min)); + } + return max_range; +} -namespace { +bool IntSet::is_everything() const { + const IntervalSet* s_int = (*this).as(); + return (s_int && s_int->i.is_everything()); +} -inline bool Match(const Expr& e, int64_t value) { - const ir::IntImm* v = e.as(); - return v != nullptr && v->value; +bool IntSet::is_single_point() const { + const IntervalSet* s_int = (*this).as(); + return (s_int && s_int->i.is_single_point()); } -// whether a exactly matches b. -inline bool Match(const IntSet& a, - const Range& b) { - if (a->base == b && - a->domain.size() == 0 && - a->concrete.size() == 0) { - return true; - } else { - return false; - } +IntSet IntSet::everything() { + return IntervalSet::make(Interval::everything()); } -// whether a exactly matches b. -inline bool Match(const IntSet& a, - const Expr& b) { - if (a->domain.size() == 0 && - a->concrete.size() == 0) { - return Match(a->base->extent, 1) && a->base->min.same_as(b); - } else { - return false; - } +IntSet IntSet::single_point(Expr x) { + return IntervalSet::make(Interval::single_point(x)); } -inline bool IsNumber(const IntSet& s) { - if (s->domain.size() != 0) return false; - if (s->concrete.size() != 0) { - return s->concrete.size() == 1; +IntSet IntSet::range(Range r) { + // must make sure it can be matched back by MatchRange. + if (is_one(r->extent)) { + return IntSet::single_point(r->min); + } + if (is_positive_const(r->extent) && is_const(r->min)) { + return IntervalSet::make( + r->min, ComputeExpr(ComputeExpr(r->extent, r->min), 1)); } - return Match(s->base->extent, 1); + return IntervalSet::make(r->min, (r->extent + r->min) - 1); } -inline Expr AsNumber(const IntSet& s) { - return s->base->min; +// Check if a is created from b. +inline bool MatchRange(const IntSet& a, + const Range& b) { + const IntervalSet* a_int = a.as(); + if (!a_int) return false; + const Interval& i = a_int->i; + if (!i.min.same_as(b)) return false; + if (is_one(b->extent)) return i.is_single_point(); + if (is_positive_const(b->extent) && is_const(b->min)) { + // deep equality + return Equal( + ComputeExpr(ComputeExpr(b->extent, b->min), 1), + a_int->i.max); + } + const Sub* sub = i.max.as(); + if (!sub) return false; + if (is_one(sub->b)) return false; + const Add* add = sub->a.as(); + return add && + add->a.same_as(b->min) && + add->b.same_as(b->extent); } -// set combination rule by operators -template -inline IntSet BinaryCombine(IntSet a, IntSet b) { - LOG(WARNING) << "cannot evaluate binary op " << T::_type_key; - return IntSet::make_all_set(); +inline bool MatchPoint(const IntSet& a, + const Expr& b) { + const IntervalSet* a_int = a.as(); + if (!a_int) return false; + const Interval& i = a_int->i; + return i.is_single_point() && i.min.same_as(b); } -template<> -inline IntSet BinaryCombine(IntSet a, IntSet b) { - auto n = std::make_shared(*(a.operator->())); - for (size_t i = 0; i < b->domain.size(); ++i) { - n->domain.push_back(b->domain[i]); - n->stride.push_back(b->stride[i]); - } - - if (IsNumber(a)) { - n->base = Range::make_with_min_extent( - a->base->min + b->base->min, - b->base->extent); - } else if (IsNumber(b)) { - n->base = Range::make_with_min_extent( - a->base->min + b->base->min, - a->base->extent); - } else { - n->base = Range::make_with_min_extent( - a->base->min + b->base->min, - a->base->extent + b->base->extent - 1); +IntSet Union(const Array& set) { + if (set.size() == 1) return set[0]; + Interval x = set[0].cover_interval().as()->i; + for (size_t i = 1; i < set.size(); ++i) { + x.include(set[i].cover_interval().as()->i); } - return IntSet(n); + return IntervalSet::make(x); } -inline Range Negation(Range a) { - if (Match(a->extent, 1)) { - return Range::make_with_min_extent(-a->min, a->extent); - } else { - return Range::make_with_min_extent(-(a->min + a->extent - 1), a->extent); +// type traits +template +struct is_logical_op { + static const bool value = false; +}; + +#define TVM_DECLARE_LOGICAL_OP(OP) \ + template<> \ + struct is_logical_op { \ + static const bool value = true; \ + }; + +// interval related. +template +inline IntSet CombineInterval(Interval a, Interval b) { + if (a.is_single_point() && b.is_single_point()) { + return IntSet::single_point(ComputeExpr(a.min, b.min)); } + LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key; + return IntSet::everything(); } -inline IntSet Negation(IntSet a) { - CHECK_EQ(a->concrete.size(), 0U); - auto n = std::make_shared(); - n->base = Negation(a->base); - for (size_t i = 0; i < a->domain.size(); ++i) { - n->domain.push_back(Negation(a->domain[i])); - n->stride.push_back(a->stride[i]); +template<> +inline IntSet CombineInterval(Interval a, Interval b) { + if (a.is_single_point() && b.is_single_point()) { + return IntSet::single_point(ComputeExpr(a.min, b.min)); + } + Interval r = Interval::everything(); + if (a.has_lower_bound() && b.has_lower_bound()) { + r.min = ComputeExpr(a.min, b.min); } - return IntSet(a); + if (a.has_upper_bound() && b.has_upper_bound()) { + r.max = ComputeExpr(a.max, b.max); + } + return IntervalSet::make(r); } template<> -inline IntSet BinaryCombine(IntSet a, IntSet b) { - return BinaryCombine(a, Negation(b)); +inline IntSet CombineInterval(Interval a, Interval b) { + if (a.is_single_point() && b.is_single_point()) { + return IntSet::single_point(ComputeExpr(a.min, b.min)); + } + Interval r = Interval::everything(); + if (a.has_lower_bound() && b.has_upper_bound()) { + r.min = ComputeExpr(a.min, b.max); + } + if (a.has_upper_bound() && b.has_lower_bound()) { + r.max = ComputeExpr(a.max, b.min); + } + return IntervalSet::make(r); } -inline IntSet BinaryMul(IntSet a, Expr b) { - // copy construct - if (Match(b, 1)) return a; - if (Match(b, -1)) return Negation(a); - auto n = std::make_shared(); - n->base = Range::make_with_min_extent(0, 1); - n->domain.push_back(a->base); - n->stride.push_back(b); - for (size_t i = 0; i < a->domain.size(); ++i) { - n->domain.push_back(a->domain[i]); - n->stride.push_back(a->stride[i] * b); - } - return IntSet(a); +template<> +inline IntSet CombineInterval(Interval a, Interval b) { + if (a.is_single_point() && b.is_single_point()) { + return IntSet::single_point(ComputeExpr(a.min, b.min)); + } + if (a.is_single_point() && !b.is_single_point()) { + std::swap(a, b); + } + if (b.is_single_point()) { + if (is_zero(b.min)) return IntSet::single_point(0); + if (is_one(b.min)) return IntervalSet::make(a); + Expr e1 = a.has_lower_bound() ? ComputeExpr(a.min, b.min) : a.min; + Expr e2 = a.has_upper_bound() ? ComputeExpr(a.max, b.min) : a.max; + // This is relaxiation + // TODO(tqchen): consider convert to StrideSet. + if (is_positive_const(b.min)) { + return IntervalSet::make(e1, e2); + } else if (is_negative_const(b.min)) { + return IntervalSet::make(e2, e1); + } else if (a.is_bounded()) { + Expr cmp = b.min >= make_zero(b.min.type().element_of()); + return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1)); + } + } + LOG(WARNING) << "Return Everything in CombineInterval Mul"; + return IntSet::everything(); } template<> -inline IntSet BinaryCombine(IntSet a, IntSet b) { - if (IsNumber(a)) { - return BinaryMul(a, AsNumber(b)); - } else if (IsNumber(b)) { - return BinaryMul(b, AsNumber(a)); - } else { - return IntSet::make_all_set(); +inline IntSet CombineInterval(Interval a, Interval b) { + if (a.is_single_point() && b.is_single_point()) { + return IntSet::single_point(ComputeExpr(a.min, b.min)); } + return IntervalSet::make(Interval::make_max(a.min, b.min), + Interval::make_max(a.max, b.max)); } -} // namespace - -inline const IntSetNode* IntSet::operator->() const { - return static_cast(node_.get()); +template<> +inline IntSet CombineInterval(Interval a, Interval b) { + if (a.is_single_point() && b.is_single_point()) { + return IntSet::single_point(ComputeExpr(a.min, b.min)); + } + return IntervalSet::make(Interval::make_min(a.min, b.min), + Interval::make_min(a.max, b.max)); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const IntSetNode *op, IRPrinter *p) { - p->stream << "int-set(base="; - p->print(op->base); - p->stream << ')'; - }); +template +inline IntSet CombineInterval_(IntSet a, IntSet b) { + return CombineInterval( + a.as()->i, b.as()->i); +} -IntSet IntSet::make_range(Range dom) { - auto n = std::make_shared(); - n->base = dom; +// stride related +inline IntSet AsStrideSet(IntSet a) { + if (a.as()) return a; + const IntervalSet* s = a.as(); + CHECK(s->i.is_bounded()); + std::shared_ptr n = std::make_shared(); + n->base = s->i; return IntSet(n); } +template +inline IntSet CombineSets(IntSet a, IntSet b) { + return CombineInterval_(a.cover_interval(), b.cover_interval()); +} -Range IntSet::GetCoverRange() const { - const IntSetNode* s = operator->(); - CHECK(s != nullptr) << "empty set"; - if (s->domain.size() == 0 && s->concrete.size() == 0) { - return s->base; +template<> +inline IntSet CombineSets(IntSet a, IntSet b) { + const IntervalSet* a_int = a.as(); + const IntervalSet* b_int = b.as(); + if (a_int && is_zero(a_int->i.min)) return b; + if (b_int && is_zero(b_int->i.min)) return a; + a = AsStrideSet(a); + b = AsStrideSet(b); + const StrideSet* a_stride = a.as(); + const StrideSet* b_stride = b.as(); + auto n = std::make_shared(*a_stride); + for (size_t i = 0; i < b_stride->extents.size(); ++i) { + n->extents.push_back(b_stride->extents[i]); + n->strides.push_back(b_stride->strides[i]); } - LOG(FATAL) << "not yet implemented"; - return Range(); + n->base = CombineInterval( + a_stride->base, b_stride->base).as()->i; + return IntSet(n); } -IntSet IntSet::make_point(Expr point) { - return IntSet::make_range(Range::make_with_min_extent(point, 1)); +inline IntSet NegateSet(IntSet a) { + const IntervalSet* a_int = a.as(); + if (a_int) { + if (a_int->i.is_single_point()) { + return IntSet::single_point(-a_int->i.min); + } else { + Interval r = Interval::everything(); + if (a_int->i.has_upper_bound()) { + r.min = -(a_int->i.max); + } + if (a_int->i.has_lower_bound()) { + r.max = -(a_int->i.min); + } + return IntervalSet::make(r); + } + } else { + return NegateSet(a.cover_interval()); + } } -IntSet IntSet::make_all_set() { - LOG(FATAL) << "TODO"; - return IntSet(); +template<> +inline IntSet CombineSets(IntSet a, IntSet b) { + return CombineSets(a, NegateSet(b)); } -IntSet Union(const Array& set) { - if (set.size() == 1) return set[0]; - LOG(FATAL) << "TODO"; - return IntSet(); +TVM_DECLARE_LOGICAL_OP(And); +TVM_DECLARE_LOGICAL_OP(Or); +TVM_DECLARE_LOGICAL_OP(EQ); +TVM_DECLARE_LOGICAL_OP(NE); +TVM_DECLARE_LOGICAL_OP(GE); +TVM_DECLARE_LOGICAL_OP(GT); +TVM_DECLARE_LOGICAL_OP(LE); +TVM_DECLARE_LOGICAL_OP(LT); +TVM_DECLARE_LOGICAL_OP(Not); + +// generic combine operations of two sets +template +inline IntSet Combine(const IntSet& a, const IntSet &b) { + if (is_logical_op::value) { + return IntervalSet::make(0, 1); + } + const IntervalSet* a_int = a.as(); + const IntervalSet* b_int = b.as(); + if (a_int && a_int->i.is_everything()) return a; + if (b_int && b_int->i.is_everything()) return b; + if (a_int && b_int) { + return CombineInterval(a_int->i, b_int->i); + } + if (a_int && !(a_int->i.is_bounded())) { + return CombineInterval_(a, b.cover_interval()); + } + if (b_int && !(b_int->i.is_bounded())) { + return CombineInterval_(a.cover_interval(), b); + } + return CombineSets(a, b); } +// Implementation of Evaluations and passing. void PassUp(const SplitNode* s, const std::unordered_map& dom_map, const IntSet& outer, @@ -215,33 +358,21 @@ void PassUp(const SplitNode* s, if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) && - Match(outer, dom_map.at(s->outer)) && - Match(inner, dom_map.at(s->inner))) { - *parent = IntSet::make_range(dom_map.at(s->parent)); + MatchRange(outer, dom_map.at(s->outer)) && + MatchRange(inner, dom_map.at(s->inner))) { + *parent = IntSet::range(dom_map.at(s->parent)); return; } Expr factor = dom_map.at(s->inner)->extent; + Expr parent_min = dom_map.at(s->parent)->min; CHECK(outer.defined()); CHECK(inner.defined()); CHECK(factor.defined()); - // copy construct - auto n = std::make_shared(*(inner.operator->())); - - if (IsNumber(outer)) { - // shift the base offset - n->base = Range::make_with_min_extent( - AsNumber(outer) * factor + inner->base->min, - inner->base->extent); - } else { - // default use all domains in the data. - n->domain.push_back(outer->base); - n->stride.push_back(factor); - for (size_t i = 0; i < outer->domain.size(); ++i) { - n->domain.push_back(outer->domain[i]); - n->stride.push_back(outer->stride[i] * factor); - } - } - *parent = IntSet(n); + + *parent = Combine( + Combine( + Combine(outer, IntSet::single_point(factor)), inner), + IntSet::single_point(parent_min)); } void PassUp(const FuseNode* s, @@ -253,29 +384,51 @@ void PassUp(const FuseNode* s, CHECK(dom_map.count(s->inner)); CHECK(dom_map.count(s->fused)); - if (Match(fused, dom_map.at(s->fused))) { - *outer = IntSet::make_range(dom_map.at(s->outer)); - *inner = IntSet::make_range(dom_map.at(s->inner)); + if (MatchRange(fused, dom_map.at(s->fused))) { + *outer = IntSet::range(dom_map.at(s->outer)); + *inner = IntSet::range(dom_map.at(s->inner)); return; } - if (IsNumber(fused)) { - Expr value = AsNumber(fused); + Expr outer_min = dom_map.at(s->outer)->min; + Expr inner_min = dom_map.at(s->inner)->min; + + const IntervalSet* fused_int = fused.as(); + + if (fused_int && fused_int->i.is_single_point()) { + Expr value = fused_int->i.min; Expr factor = dom_map.at(s->inner)->extent; - *outer = IntSet::make_point(value / factor); - *inner = IntSet::make_point(value % factor); + Expr v_outer = value / factor; + Expr v_inner = value % factor; + if (!is_zero(outer_min)) v_outer = v_outer + outer_min; + if (!is_zero(inner_min)) v_inner = v_inner + inner_min; + *outer = IntSet::single_point(v_outer); + *inner = IntSet::single_point(v_inner); } else { LOG(WARNING) << "use fallback inference rule in fuse"; // simply use the entire set, this rule can be enhanced. - *outer = IntSet::make_range(dom_map.at(s->outer)); - *inner = IntSet::make_range(dom_map.at(s->inner)); + *outer = IntSet::range(dom_map.at(s->outer)); + *inner = IntSet::range(dom_map.at(s->inner)); + return; + } +} + + +void PassUp(const RebaseNode* s, + const std::unordered_map& dom_map, + const IntSet& rebased, + IntSet* parent) { + CHECK(dom_map.count(s->parent)); + if (MatchRange(rebased, dom_map.at(s->rebased))) { + *parent = IntSet::range(dom_map.at(s->parent)); return; } + Expr parent_min = dom_map.at(s->parent)->min; + *parent = Combine(rebased, IntSet::single_point(parent_min)); } -namespace { -// evaluator to evaluate the int set -class IRSetEvaluator { +// Evaluator to evalute the epxression. +class IntSetEvaluator { public: inline IntSet Eval(Expr expr) { static const FType& f = vtable(); @@ -283,11 +436,11 @@ class IRSetEvaluator { return f(expr, expr, this); } else { LOG(WARNING) << "cannot evaluate set type " << expr->type_key(); - return IntSet::make_all_set(); + return IntSet::everything(); } } - using FType = tvm::IRFunctor; + using FType = tvm::IRFunctor; static FType& vtable() { // NOLINT(*) static FType inst; return inst; } @@ -295,76 +448,84 @@ class IRSetEvaluator { std::unordered_map dom_map; }; -inline IntSet ConstOp(const NodeRef&, const Expr& e, IRSetEvaluator*) { - return IntSet::make_point(e); +inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) { + return IntSet::single_point(e); } -TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable) +TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) .set_dispatch(ConstOp) .set_dispatch(ConstOp) .set_dispatch(ConstOp); -TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable) -.set_dispatch([](const Variable* op, const Expr& e, IRSetEvaluator* m) { +TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) +.set_dispatch([](const Variable* op, const Expr& e, IntSetEvaluator* m) { auto it = m->dom_map.find(op); if (it != m->dom_map.end()) { return it->second; } else { - return IntSet::make_point(e); + return IntSet::single_point(e); } }); // binary operator template -inline IntSet Binary(const T* op, const Expr& e, IRSetEvaluator* m) { +inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) { IntSet a = m->Eval(op->a); IntSet b = m->Eval(op->b); - if (IsNumber(a) && IsNumber(b)) { - if (Match(a, op->a) && - Match(b, op->b)) { - return IntSet::make_point(e); - } else { - return IntSet::make_point(T::make(AsNumber(a), AsNumber(b))); - } - } else { - return BinaryCombine(a, b); + if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { + return IntSet::single_point(e); } + IntSet r = Combine(a, b); + return r; } -TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable) +TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch
(Binary
) .set_dispatch(Binary) .set_dispatch(Binary) -.set_dispatch(Binary); - -// use simply bound for logical expressions for now. -inline IntSet Logical(const NodeRef&, const Expr& e, IRSetEvaluator*) { - return IntSet::make_range(Range::make_with_min_extent(0, 2)); -} - -TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable) -.set_dispatch(Logical) -.set_dispatch(Logical) -.set_dispatch(Logical) -.set_dispatch(Logical) -.set_dispatch(Logical) -.set_dispatch(Logical) -.set_dispatch(Logical) -.set_dispatch(Logical); - -} // namespace +.set_dispatch(Binary) +.set_dispatch(Binary) +.set_dispatch(Binary) +.set_dispatch(Binary) +.set_dispatch(Binary) +.set_dispatch(Binary) +.set_dispatch(Binary) +.set_dispatch(Binary) +.set_dispatch(Binary); IntSet EvalSet(Expr e, const Map& dom_map) { - IRSetEvaluator m; + IntSetEvaluator m; for (auto kv : dom_map) { m.dom_map[kv.first->var.as()] = kv.second; } return m.Eval(e); } +IntSet EvalSet(Range r, + const Map& dom_map) { + IntSetEvaluator m; + for (auto kv : dom_map) { + m.dom_map[kv.first->var.as()] = kv.second; + } + IntSet min_set = m.Eval(r->min); + IntSet ext_set = m.Eval(r->extent).cover_interval(); + const Interval& ei = ext_set.as()->i; + if (!ei.has_upper_bound()) return IntSet::everything(); + ext_set = IntervalSet::make(0, ComputeExpr(ei.max, 1)); + return Combine(min_set, ext_set); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const IntervalSet *op, IRPrinter *p) { + p->stream << "interval-set[" + << "[" << op->i.min << ", " + << op->i.max << ']'; + }); + + } // namespace schedule } // namespace tvm diff --git a/src/schedule/int_set.h b/src/schedule/int_set.h index 1428e5bf9c4a..5866c123d6c7 100644 --- a/src/schedule/int_set.h +++ b/src/schedule/int_set.h @@ -22,35 +22,48 @@ class IntSet : public NodeRef { public: /*! \brief constructor */ IntSet() {} - // constructor from not deontainer. + // constructor from not container. explicit IntSet(std::shared_ptr n) : NodeRef(n) {} - /*! \return whether the set is empty */ - inline bool is_empty() const { - return !defined(); - } - /*! - * \return a range that covers the IntSet - */ - Range GetCoverRange() const; /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const IntSetNode* operator->() const; /*! - * \param dom The domain to be created. - * \return create integer set from existing domain + * \brief Find a range that covers the region. + * \param max_range The range to be covered. + * \return The covering range. + */ + Range cover_range(Range max_range) const; + /*! + * \brief find an interval that covers the set. + * \return The covering interval set. */ - static IntSet make_range(Range dom); + IntSet cover_interval() const; + /*! \return Whether the set represent everything */ + bool is_everything() const; + /*! \return Whether the set is a single point */ + bool is_single_point() const; + /*! \return Whether the set contains everything */ + static IntSet everything(); /*! - * \param point - * \return create integer set that only contains one point + * \brief construct a point set. + * \param point The point in the set. + * \return construct a single point set */ - static IntSet make_point(Expr point); + static IntSet single_point(Expr point); /*! - * \return create integer set that represents everything + * \brief Construct a set representing a range. + * \param r The range + * \return constructed set. */ - static IntSet make_all_set(); + static IntSet range(Range r); +}; + +/*! + * \brief Base class of all IntSet containers. + */ +struct IntSetNode : public Node { }; /*! @@ -63,6 +76,18 @@ class IntSet : public NodeRef { */ IntSet EvalSet(Expr e, const Map& dom_map); + +/*! + * \brief Find an symbolic integer set that contains is union over + * all the possible conditional values in dom_map. + * + * \param r The initial range. + * \param dom_map The domain of each variable. + * \return An integer set that can cover all the possible values. + */ +IntSet EvalSet(Range r, + const Map& dom_map); + /*! * \brief Conditional upward message passing. * @@ -99,6 +124,23 @@ void PassUp(const FuseNode* s, const IntSet& fused, IntSet* outer, IntSet* inner); + +/*! + * \brief Conditional upward message passing. + * + * Get domain of parent, condition on domain of children. + * Domain is represented as IntSet. + * + * \param s The Fuse relation node. + * \param dom_map The old domain result from downward message passing. + * Contains the domain set if all the children are full set. + * \param rebased domain of rebased iteration. + * \param parent The result domain of parent iteration. + */ +void PassUp(const RebaseNode* s, + const std::unordered_map& dom_map, + const IntSet& fused, + IntSet* parent); /*! * \brief Create an union set of all sets * \param sets The sets to be unioned @@ -106,6 +148,11 @@ void PassUp(const FuseNode* s, */ IntSet Union(const Array& sets); +// implementation +inline const IntSetNode* IntSet::operator->() const { + return static_cast(node_.get()); +} + } // namespace schedule } // namespace tvm diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 1f266126efed..58368ceb93b4 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -81,7 +81,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) } } CHECK(found) - << "Cannot compute at a iteration variable that is not part of parent leaf vars"; + << "Cannot find the specified axis in parent stage's leaf_iter_vars"; return *this; } @@ -165,7 +165,6 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent, return *this; } - Schedule::Schedule(Array ops) { auto n = std::make_shared(); n->roots = ops; @@ -203,9 +202,53 @@ IterVarRelation FuseNode::make( return IterVarRelation(n); } +IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { + auto n = std::make_shared(); + n->parent = parent; + n->rebased = rebased; + return IterVarRelation(n); +} + +void Schedule::normalize() { + std::unordered_map rebase_map; + std::unordered_map attach_mark; + + + for (Stage s : (*this)->stages) { + if (s->attach_type == kScope) { + attach_mark[s->attach_stage.get()] = 1; + } + } + + for (Stage s : (*this)->stages) { + if (!attach_mark.count(s.get())) continue; + auto root_iter_vars = s->op->root_iter_vars(); + ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); + + for (IterVar iv : root_iter_vars) { + size_t idx = FindIterVar(leaf_vars, iv); + if (idx < leaf_vars->data.size()) { + // insert rebase + IterVar rebased(Range(), iv->var->name_hint + ".rb"); + s->relations.push_back(RebaseNode::make(iv, rebased)); + leaf_vars->data[idx] = rebased.node_; + rebase_map[iv] = rebased; + } + } + } + // remap the parent relation + for (Stage s : (*this)->stages) { + if (s->attach_type != kScope) continue; + if (rebase_map.count(s->attach_ivar)) { + s->attach_ivar = rebase_map.at(s->attach_ivar); + } + } +} + TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(SplitNode); TVM_REGISTER_NODE_TYPE(FuseNode); +TVM_REGISTER_NODE_TYPE(RebaseNode); TVM_REGISTER_NODE_TYPE(ScheduleNode); } // namespace tvm diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc new file mode 100644 index 000000000000..61f0347bcd2b --- /dev/null +++ b/src/schedule/schedule_ops.cc @@ -0,0 +1,388 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file schedule_ops.cc + */ +#include +#include +#include +#include +#include + +#include "../pass/ir_util.h" +#include "./int_set.h" +#include "./graph.h" + +namespace tvm { +namespace schedule { + +using namespace ir; + +/*! + * \brief message passing to find if IterVar is related to reduction. + * \param s The stage to be used. + * \param p_state The message passing state + * IterVar->flag + */ +void PassDownFlag(const Stage& s, + std::unordered_map* p_state) { + auto& state = *p_state; + for (IterVarRelation rel : s->relations) { + if (rel.as()) { + const SplitNode* s = rel.as(); + int flag = state.at(s->parent); + state[s->outer] = flag; + state[s->inner] = flag; + } else if (rel.as()) { + const FuseNode* s = rel.as(); + int flag_outer = state.at(s->outer); + int flag_inner = state.at(s->inner); + state[s->fused] = flag_outer | flag_inner; + } else if (rel.as()) { + const RebaseNode* s = rel.as(); + int flag = state.at(s->parent); + state[s->rebased] = flag; + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + +/*! + * \brief use message passing to calculate the assignment of each Var inside the loop body. + * \param s The schedule to be used. + * \param dom_map The domain map of each iteration variable's domain + * \param p_state The message passing state + * IterVar->The assignment. + */ +void PassUpOffset(const Stage& s, + const Map& dom_map, + std::unordered_map* p_state) { + auto& state = *p_state; + for (size_t i = s->relations.size(); i != 0; --i) { + IterVarRelation rel = s->relations[i - 1]; + if (rel.as()) { + const SplitNode* s = rel.as(); + Expr outer = state.at(s->outer); + Expr inner = state.at(s->inner); + Expr factor = dom_map.at(s->inner)->extent; + Expr parent_min = dom_map.at(s->parent)->min; + state[s->parent] = inner + outer * factor; + // add min if they exist + if (!is_zero(parent_min)) { + state[s->parent] = state[s->parent] + parent_min; + } + } else if (rel.as()) { + const FuseNode* s = rel.as(); + Expr value = state.at(s->fused); + Expr factor = dom_map.at(s->inner)->extent; + Expr outer_min = dom_map.at(s->outer)->min; + Expr inner_min = dom_map.at(s->inner)->min; + state[s->outer] = value / factor; + state[s->inner] = value % factor; + // add min if they exist + if (!is_zero(outer_min)) { + state[s->outer] = state[s->outer] + outer_min; + } + if (!is_zero(inner_min)) { + state[s->inner] = state[s->inner] + inner_min; + } + } else if (rel.as()) { + const RebaseNode* s = rel.as(); + Expr value = state.at(s->rebased); + Expr parent_min = dom_map.at(s->parent)->min; + // add min if they exist + if (!is_zero(parent_min)) { + state[s->parent] = value + parent_min; + } else { + state[s->parent] = value; + } + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + +std::vector > +MakeLoopNest(const Stage& sch, + const Map& dom_map, + size_t begin_loop, + bool reduce_init_loop, + std::unordered_map* p_value_map, + const std::unordered_map& skip_iter) { + auto leaf_iter_vars = sch->leaf_iter_vars; + Stmt no_op = Evaluate::make(0); + // create the loop nest + std::vector > nest; + nest.resize(leaf_iter_vars.size() + 1); + std::unordered_map& value_map = *p_value_map; + + for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) { + auto iv = leaf_iter_vars[i]; + if (skip_iter.count(iv) && skip_iter.at(iv)) { + // skip this iteration. + value_map[iv] = iv->var; + continue; + } + + Range dom = dom_map.at(iv); + // initialize the offset and loop_level + Var var = iv->var; + if (reduce_init_loop) { + var = Var(iv->var->name_hint + ".init", iv->var.type()); + } + // Mark the iter var in the IR, to remember the point + if (iv->thread_tag.length() == 0) { + if (is_one(dom->extent)) { + nest[i + 1].emplace_back( + LetStmt::make(var, dom->min, no_op)); + value_map[iv] = dom->min; + } else if (is_zero(dom->min)) { + nest[i + 1].emplace_back( + For::make(var, 0, dom->extent, + ForType::Serial, DeviceAPI::None, no_op)); + value_map[iv] = var; + } else { + Var idx(iv->var->name_hint + ".idx", iv->var.type()); + nest[i + 1].emplace_back( + For::make(idx, 0, dom->extent, + ForType::Serial, DeviceAPI::None, no_op)); + Expr new_value = dom->min + idx; + value_map[iv] = new_value; + nest[i + 1].emplace_back( + LetStmt::make(var, new_value, no_op)); + } + } else { + // Always restrict threaded IterVar to starts from 0. + CHECK(is_zero(dom->min)); + // annotate the extent of the IterVar + nest[i + 1].emplace_back( + AttrStmt::make(iv, "thread_extent", dom->extent, no_op)); + value_map[iv] = var; + } + if (!reduce_init_loop) { + // annotate the extent of the IterVar + nest[i + 1].emplace_back( + AttrStmt::make(iv, "scope", iv->var, no_op)); + } + } + // message passing to get offset of root iter vars. + PassUpOffset(sch, dom_map, &value_map); + return nest; +} + +Stmt MakeLoop(const Stage& s, + const Map& dom_map, + Stmt provide, + Stmt init) { + std::unordered_map value_map; + auto nest = MakeLoopNest(s, dom_map, 0, false, &value_map, {}); + provide = Substitute(provide, value_map); + if (init.defined()) { + // try to find the location to insert the initialization. + // Fuse the initialization and provide loop when possible. + std::unordered_map reduce_state; + const ComputeOpNode* compute = s->op.as(); + for (IterVar iv : compute->reduce_axis) { + reduce_state[iv] = 2; + } + for (IterVar iv : compute->axis) { + reduce_state[iv] = 1; + } + // find which iter var is related to reduction and which is related to axis. + PassDownFlag(s, &reduce_state); + auto leaf_iter_vars = s->leaf_iter_vars; + std::unordered_map init_value_map; + // first first loop that is related to reduction. + size_t begin_loop = leaf_iter_vars.size(); + for (size_t i = 0; i < leaf_iter_vars.size(); ++i) { + auto iv = leaf_iter_vars[i]; + int flag = reduce_state.at(iv); + if ((flag & 2) != 0) { + begin_loop = i; break; + } + init_value_map[iv] = value_map.at(iv); + } + // skip loops that does not relates to axis. + std::unordered_map skip_iter; + for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) { + auto iv = leaf_iter_vars[i]; + int flag = reduce_state.at(iv); + if ((flag & 1) == 0) skip_iter[iv] = true; + } + auto init_nest = MakeLoopNest( + s, dom_map, begin_loop, true, &init_value_map, skip_iter); + init = Substitute(init, init_value_map); + init = MergeNest(init_nest, init); + // common nest + std::vector > common(nest.begin(), nest.begin() + begin_loop); + std::vector > reduce(nest.begin() + begin_loop, nest.end()); + provide = MergeNest(reduce, provide); + return MergeNest( + common, Block::make(init, provide)); + } else { + return MergeNest(nest, provide); + } +} + +Stmt MakeProvide(const ComputeOpNode* op, + const std::vector& tensors) { + Tensor t = tensors[0]; + Array args; + for (IterVar iv : op->axis) { + args.push_back(iv->var); + } + return Provide::make(t->op, t->value_index, op->body, args); +} + +Stmt MakeRealize(const ComputeOpNode* op, + const Map& dom_map, + const std::vector& tensors, + Stmt body) { + Tensor t = tensors[0]; + Halide::Internal::Region bounds; + for (IterVar iv : op->axis) { + bounds.push_back(dom_map.at(iv)); + } + return Realize::make(t->op, t->value_index, t->dtype, + bounds, make_const(Bool(1), true), body); +} + + +void MakeReduction(const ComputeOpNode* op, + const std::vector& tensors, + const Map& dom_map, + Stmt* init, + Stmt* provide) { + Stmt no_op = Evaluate::make(0); + Tensor t = tensors[0]; + std::vector nest; + Array args; + for (IterVar iv : op->axis) { + args.push_back(iv->var); + } + const Reduce* reduce = op->body.as(); + CHECK(reduce); + Expr init_value, update_value; + if (reduce->op == "Add") { + init_value = make_zero(reduce->type); + update_value = Add::make(t(args), reduce->source); + } else if (reduce->op == "Max") { + init_value = reduce->type.min(); + update_value = Max::make(t(args), reduce->source); + } else if (reduce->op == "Min") { + init_value = reduce->type.max(); + update_value = Min::make(t(args), reduce->source); + } else { + LOG(FATAL) << "Unsupported reduction " << reduce->op; + } + *init = Provide::make(t->op, t->value_index, init_value, args); + *provide = Provide::make(t->op, t->value_index, update_value, args); +} + +Stmt MakePipeline(const Stage& sch, + const Map& dom_map, + Stmt consumer) { + std::vector tensors; + for (int i = 0; i < sch->op->num_outputs(); ++i) { + tensors.emplace_back(sch->op.output(i)); + } + + Stmt init, provide; + + const ComputeOpNode* compute = sch->op.as(); + if (compute) { + if (compute->reduce_axis.size() == 0) { + provide = MakeProvide(compute, tensors); + } else { + MakeReduction(compute, tensors, dom_map, &init, &provide); + } + } else { + LOG(FATAL) << "not supported op " << sch->op->type_key(); + } + + Stmt producer = MakeLoop(sch, dom_map, provide, init); + producer = ProducerConsumer::make(sch->op, true, producer); + + Stmt pipeline = producer; + if (consumer.defined()) { + consumer = ProducerConsumer::make(sch->op, false, consumer); + pipeline = Block::make(producer, consumer); + } + + if (sch->op.as()) { + return MakeRealize(sch->op.as(), + dom_map, tensors, pipeline); + } else { + LOG(FATAL) << "not supported op"; + return Stmt(); + } +} + +// inject the operator's realization on the stmt. +class InjectRealize : public IRMutator { + public: + InjectRealize(Stage schedule, Map dom_map) + : schedule(schedule), dom_map(dom_map) {} + + Stmt Mutate(Stmt stmt) final { + CHECK(stmt.defined()); + stmt = IRMutator::Mutate(stmt); + const AttrStmt* op = stmt.as(); + if (op != nullptr && + op->type_key == "scope") { + if (op->node == schedule->attach_ivar) { + CHECK(!found_attach); + found_attach = true; + stmt = AttrStmt::make( + op->node, op->type_key, op->value, + MakePipeline(schedule, dom_map, + IRMutator::Mutate(op->body))); + } + } + return stmt; + } + // the operations to be carried + Stage schedule; + // domain map + Map dom_map; + // whether attach point is found + bool found_attach{false}; +}; + +Stmt InjectInline(const Operation op, Stmt body) { + CHECK(body.defined()); + const ComputeOpNode* compute = op.as(); + CHECK(compute != nullptr) + << "can only inline compute op"; + Array args; + for (auto iv : compute->axis) { + args.push_back(iv->var); + } + return Inline(body, op, args, compute->body); +} + +Stmt ScheduleOps( + Schedule sch, Map dom_map) { + Stmt body = Stmt(); + // reverse the post DFS order. + for (size_t i = sch->stages.size(); i != 0; --i) { + Stage s = sch->stages[i - 1]; + // no need to specify place holder op. + if (s->op.as()) continue; + if (s->attach_type == kInline) { + body = InjectInline(s->op, body); + } else if (s->attach_type == kRoot || s-> attach_type == kNone) { + body = MakePipeline(s, dom_map, body); + } else if (s->attach_type == kScope) { + CHECK(body.defined()); + InjectRealize mutator(s, dom_map); + body = mutator.Mutate(body); + CHECK(mutator.found_attach) + << "did not find attachment point"; + } + } + return body; +} + +} // namespace schedule +} // namespace tvm diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 0395d633b5ed..6f80bc150bd7 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -18,7 +18,8 @@ def test_add(): # one line to build the function. codes = [] - fadd = tvm.build(s, args=[A, B, C], + fadd = tvm.build(s, + args=[A, B, C], target="cuda", name="myadd", record_codes=codes) for c in codes: diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py new file mode 100644 index 000000000000..30646cf6b9e6 --- /dev/null +++ b/tests/python/integration/test_reduce.py @@ -0,0 +1,45 @@ +import tvm +import numpy as np + +def test_sum(): + # graph + n = tvm.Var('n') + m = tvm.Var('m') + A = tvm.placeholder((n, m), name='A') + k = tvm.IterVar((0, m)) + B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B') + # schedule + s = tvm.Schedule(B.op) + # create iter var and assign them tags. + num_thread = 1 + block_x = tvm.IterVar(thread_tag="blockIdx.x") + thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") + _, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x) + _, x = s[B].split(x, outer=thread_x) + + tvm.init_opencl() + codes = [] + fsum = tvm.build(s, + args=[A, B], + target="opencl", name="myadd", + record_codes=codes) + for c in codes: + print(c) + num_device = 1 + for i in range(num_device): + ctx = tvm.opencl(i) + if not ctx.enabled: + continue + # launch the kernel. + n = 1028 + m = 129 + #a = tvm.nd.array(np.zeros((n, m)).astype(A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx) + fsum(a, b) + np.testing.assert_allclose( + b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4) + + +if __name__ == "__main__": + test_sum() diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index e8aba60a9af8..56a9c29d8c2a 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -18,8 +18,7 @@ def test_add_pipeline(): # compile to IR bounds = tvm.schedule.InferBound(s) - stmt = tvm.ir_pass.ScheduleOps(s, bounds) - + stmt = tvm.schedule.ScheduleOps(s, bounds) Ab = tvm.Buffer(A.shape, A.dtype, name='A') Bb = tvm.Buffer(B.shape, B.dtype, name='B') Cb = tvm.Buffer(C.shape, C.dtype, name='C') diff --git a/tests/python/unittest/test_codegen_makeapi.py b/tests/python/unittest/test_codegen_makeapi.py index fd6522a2d50c..689556db9f28 100644 --- a/tests/python/unittest/test_codegen_makeapi.py +++ b/tests/python/unittest/test_codegen_makeapi.py @@ -10,12 +10,13 @@ def test_makeapi(): s = tvm.Schedule(C.op) bounds = tvm.schedule.InferBound(s) - stmt = tvm.ir_pass.ScheduleOps(s, bounds) + stmt = tvm.schedule.ScheduleOps(s, bounds) Ab = tvm.Buffer(A.shape, A.dtype, name='A') Bb = tvm.Buffer(B.shape, B.dtype, name='B') Cb = tvm.Buffer(C.shape, C.dtype, name='C') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) + num_packed_args = 2 f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args) assert(f.handle_data_type[Ab.data].dtype == Ab.dtype) diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 01ab5109f628..9d9115f5c2ea 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -26,7 +26,7 @@ def test_tensor_reduce(): B = tvm.placeholder((n, l), name='B') T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k]) rv = tvm.IterVar((0, A.shape[1]), name="k") - C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv)) + C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv)) # json load save C_json = tvm.save_json(C) C_loaded = tvm.load_json(C_json) diff --git a/tests/python/unittest/test_pass_storage_flatten.py b/tests/python/unittest/test_pass_storage_flatten.py index 98200bc7d528..a981b513a34b 100644 --- a/tests/python/unittest/test_pass_storage_flatten.py +++ b/tests/python/unittest/test_pass_storage_flatten.py @@ -12,7 +12,7 @@ def test_flatten2(): s[A1].compute_at(s[A2], xo) bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) - stmt = tvm.ir_pass.ScheduleOps(s, bounds) + stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) Ab = tvm.Buffer(A.shape, A.dtype, name='A') diff --git a/tests/python/unittest/test_pass_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py similarity index 89% rename from tests/python/unittest/test_pass_schedule_ops.py rename to tests/python/unittest/test_schedule_schedule_ops.py index e634d0773b0c..feed951e295f 100644 --- a/tests/python/unittest/test_pass_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -11,7 +11,7 @@ def test_schedule0(): bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) - stmt = tvm.ir_pass.ScheduleOps(s, bounds) + stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) def test_schedule1(): @@ -24,7 +24,7 @@ def test_schedule1(): xo, xi = s[A1].split(A1.op.axis[0], 8) bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) - stmt = tvm.ir_pass.ScheduleOps(s, bounds) + stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) def test_schedule2(): @@ -39,7 +39,7 @@ def test_schedule2(): s[A1].compute_at(s[A2], xo) bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) - stmt = tvm.ir_pass.ScheduleOps(s, bounds) + stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt)