From 2df0854a2b27778470ad3f4199506e8c9769200a Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Thu, 22 Jul 2021 04:59:06 +0800 Subject: [PATCH] [TensorIR][M2a] Fuse, Split (#8467) * Fuse&split (#408) Co-authored-by: jinhongyi <323195289@qq.com> Co-authored-by: Junru Shao --- include/tvm/arith/iter_affine_map.h | 12 + include/tvm/tir/schedule/schedule.h | 19 + python/tvm/tir/schedule/schedule.py | 138 +++++- src/arith/iter_affine_map.cc | 15 + src/arith/rewrite_simplify.cc | 4 + src/tir/schedule/concrete_schedule.cc | 87 ++++ src/tir/schedule/concrete_schedule.h | 43 +- src/tir/schedule/primitive.h | 22 +- .../schedule/primitive/loop_transformation.cc | 389 +++++++++++++++ src/tir/schedule/schedule.cc | 2 + .../unittest/test_tir_schedule_split_fuse.py | 453 ++++++++++++++++++ 11 files changed, 1170 insertions(+), 14 deletions(-) create mode 100644 src/tir/schedule/primitive/loop_transformation.cc create mode 100644 tests/python/unittest/test_tir_schedule_split_fuse.py diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index d671339fb66b..6c72cbeafdd4 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -282,6 +282,18 @@ class IterSumExpr : public IterMapExpr { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer); +/*! + * \brief Use IterVarMap detector to rewrite and simplify the indices + * + * \param indices The indices to detect pattern for. + * \param input_iters Map from variable to iterator's range. + * \param input_pred The predicate constraints on the input iterators + * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * + * \return The indices after rewrite + */ +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective); /*! * \brief Apply the inverse of the affine transformation to the outputs. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9a09d0ad211f..868454b18b74 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -196,6 +196,25 @@ class ScheduleNode : public runtime::Object { */ virtual Array GetLoops(const BlockRV& block_rv) = 0; /******** Schedule: loops manipulation ********/ + /*! + * \brief Fuse a list of consecutive loops into one. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The (i+1)-th loop must be the only child of the i-th loop. + * 3) All loops must start with 0. + * \param loop_rvs The loops to be fused + * \return The new loop after fusion + */ + virtual LoopRV Fuse(const Array& loop_rvs) = 0; + /*! + * \brief Split a loop into a list of consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param loop_rv The loop to be split + * \param factors The tiling factors, and at most one of which is -1, which means that + * factor is inferred. + * \return The new loops after split + */ + virtual Array Split(const LoopRV& loop_rv, const Array>& factors) = 0; /******** Schedule: compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 2091f4d80ab3..a71e2e1241be 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=unused-import """The TensorIR schedule class""" -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error @@ -43,7 +43,10 @@ class BlockRV(Object): """A random variable that refers to a block""" -ExprRV = PrimExpr # A random variable that evaluates to an integer +# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 +# This feature is not supported until python 3.10: +# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias +ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name @@ -257,6 +260,137 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member ########## Schedule: loops manipulation ########## + def fuse(self, *loops: List[LoopRV]) -> LoopRV: + """Fuse a list of consecutive loops into one. It requires: + 1) The loops can't have annotations or thread bindings. + 2) The (i+1)-th loop must be the only child of the i-th loop. + 3) All loops must start with 0. + + Parameters + ---------- + *loops : List[LoopRV] + The loops to be fused + + Returns + ---------- + fused_loop : LoopRV + The new loop after fusion + + Examples + -------- + + Before applying fuse, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do fuse: + + .. code-block:: python + + sch = tir.Schedule(before_fuse) + i, j = sch.get_loops(sch.get_block("B")) + sch.fuse(i, j) + print(tvm.script.asscript(sch.mod["main"])) + + After applying fuse, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + # the 2 loops are fused into 1 + for i_j_fused in tir.serial(0, 16384): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 128)) + tir.bind(vj, tir.floormod(i_j_fused, 128)) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + return _ffi_api_schedule.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member + + def split( + self, + loop: LoopRV, + factors: List[Union[ExprRV, None]], + ) -> List[LoopRV]: + """Split a loop into a list of consecutive loops. It requires: + 1) The loop can't have annotation or thread binding. + 2) The loop must start with 0. + Predicates may be added to ensure the total loop numbers keeps unchanged. + In `factors`, at most one of the factors can be None, + which will be automatically inferred. + + Parameters + ---------- + loop : LoopRV + The loop to be split + + factors: List[Union[ExprRV, None]] + The splitting factors + Potential inputs are: + - None + - ExprRV + - Nonnegative constant integers + + Returns + ---------- + split_loops : List[LoopRV] + The new loops after split + + Examples + -------- + + Before split, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do fuse: + + .. code-block:: python + + sch = tir.Schedule(before_split) + i, j = sch.get_loops(sch.get_block("B")) + sch.split(i, factors=[2, 64]) + print(tvm.script.asscript(sch.mod["main"])) + + After applying split, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + # the original loop is split into 2 loops + for i0, i1, j in tir.grid(2, 64, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, ((i0*64) + i1)) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + # it will be checked later in C++ implementation + # that there is at most one None in `factors` + return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member + ########## Schedule: compute location ########## def compute_inline(self, block: BlockRV) -> None: """Inline a block into its consumer(s). It requires: diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index cd482279efe0..ac78c55ed610 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1085,6 +1085,21 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter return NormalizeIterMapToExpr(expr); }); +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective) { + Analyzer analyzer; + Array rewrite = + DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + if (rewrite.empty()) { + return indices; + } + Array res; + res.reserve(rewrite.size()); + IterMapToExprNormalizer converter(&analyzer); + for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); + return res; +} + /*! * \brief Divider to divide the bindings into two sets of bindings(outer and inner) * such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X. diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a58e4433dadd..ff6536ab066b 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -882,6 +884,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); + TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 0563d39427b1..0d5bfce46e37 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -258,6 +258,93 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { } /******** Schedule: loops manipulation ********/ + +LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { + CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; + Array loop_srefs = this->GetSRefs(loop_rvs); + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Fuse(state_, loop_srefs); + TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, + const Array>& factor_rvs) { + class NotSingleInferFactorError : public ScheduleError { + public: + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + }; + + class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; + }; + // Prepare for the splitting + StmtSRef loop_sref = this->GetSRef(loop_rv); + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + Array factors; + factors.reserve(factor_rvs.size()); + int infer_index = -1; + PrimExpr tot_length = 1; + Array results; + TVM_TIR_SCHEDULE_BEGIN(); + // infer factor if needed and check validity of factors + for (size_t i = 0; i < factor_rvs.size(); i++) { + if (!factor_rvs[i].defined()) { + factors.push_back(Integer(-1)); + if (infer_index == -1) { + infer_index = i; + } else { + throw NotSingleInferFactorError(state_->mod); + } + } else { + PrimExpr factor = this->Get(factor_rvs[i].value()); + factors.push_back(factor); + tot_length *= factor; + } + } + if (infer_index != -1) { + factors.Set(infer_index, + this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); + } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { + throw WrongFactorProductError(state_->mod, GetRef(loop)); + } + results = tir::Split(state_, loop_sref, factors); + TVM_TIR_SCHEDULE_END("split", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(results); +} + /******** Schedule: compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 8945fb9ee0dc..fab3e259b752 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -68,6 +68,8 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline Array GetSRefs(const Array& rvs) const; + inline Array GetSRefs(const Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } @@ -78,6 +80,8 @@ class ConcreteScheduleNode : public ScheduleNode { BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; /******** Schedule: loops manipulation ********/ + LoopRV Fuse(const Array& loop_rvs) override; + Array Split(const LoopRV& loop_rv, const Array>& factors) override; /******** Schedule: compute location ********/ void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; @@ -143,17 +147,16 @@ inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { } inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { - auto it = this->symbol_table_.find(expr_rv); - if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << expr_rv; - } - const ObjectRef& obj = (*it).second; - const auto* expr_node = obj.as(); - if (expr_node == nullptr) { - LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " - << (obj.defined() ? obj->GetTypeKey() : "None"); - } - return GetRef(expr_node); + PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional { + auto it = this->symbol_table_.find(var); + if (it == this->symbol_table_.end()) { + LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; + } + const ObjectRef& obj = (*it).second; + const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode); + return Integer(int_imm->value); + }); + return this->analyzer_->Simplify(transformed); } inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { @@ -198,6 +201,24 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { return GetRef(sref); } +template +inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Array& rvs) { + Array result; + result.reserve(rvs.size()); + for (const T& rv : rvs) { + result.push_back(sch->GetSRef(rv)); + } + return result; +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + /******** Adding/Removing elements in the symbol table ********/ template diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index ab8299e38169..088c4df58859 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -25,7 +25,27 @@ namespace tvm { namespace tir { /******** Schedule: loops manipulation ********/ - +/*! + * Split a loop into a list of consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param self The state of the schedule + * \param loop_sref The sref to the loop being split + * \param factors The splitting factors + * \return An array of srefs to the loops after splitting + */ +TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors); +/*! + * \brief Fuse a list of consecutive loops into one. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The inner loop must be the only child of the outer loop. + * 3) All loops must start with 0. + * \param self The state of the schedule + * \param loop_srefs An array of srefs to the loops to be fused + * \return The sref to the fused loop + */ +TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); /******** Schedule: compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc new file mode 100644 index 000000000000..2a2d9ed2a888 --- /dev/null +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief Append a new predicate to the each child of type BlockRealize (not recursively) */ +class BlockPredicateAppender : public StmtMutator { + public: + /*! + * \brief Constructor + * \param to_append The predicate to be appended to BlockRealizeNode + */ + explicit BlockPredicateAppender(const PrimExpr& to_append) : to_append_(to_append) {} + + private: + // For each direct child of type BlockRealizeNode, append the predicate + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + // We do not recursively do this + ObjectPtr n = CopyOnWrite(realize); + n->predicate = n->predicate && to_append_; + return BlockRealize(n); + } + + /*! \brief The predicate to be appended */ + const PrimExpr& to_append_; +}; + +/*! \brief Substitute vars and collect the reuse mapping of opaque blocks */ +class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { + public: + explicit SubstituteVarAndCollectOpaqueBlock(std::function(const Var&)> vmap, + Map* opaque_blocks) + : vmap_(vmap), opaque_blocks_(opaque_blocks) {} + + private: + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + if (Optional ret = vmap_(var)) { + return ret.value(); + } else { + return std::move(var); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + if (realize->block->iter_vars.empty()) { + opaque_blocks_->Set(op->block, realize->block); + } + return std::move(realize); + } + + /*! \brief The substitute function */ + std::function(const Var&)> vmap_; + /*! \brief The reuse mapping of opaque blocks */ + Map* opaque_blocks_; +}; + +/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ +class IterMapSimplifyBlockBinding : public StmtExprMutator { + public: + explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map loop_var2extent) + : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {} + + static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, + MapNode* opaque_blocks) { + Map loop_var2extent; + for (const StmtSRef& sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return Downcast( + IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt))); + } + + private: + Stmt VisitStmt_(const ForNode* op) final { + loop_var2extent_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + Stmt res = StmtMutator::VisitStmt_(op); + loop_var2extent_.erase(op->loop_var); + return res; + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + // skip opaque block and update mapping + if (op->iter_values.empty()) { + Block block = op->block; + BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + for (const std::pair& entry : *opaque_blocks_) { + if (entry.second.same_as(block)) { + opaque_blocks_->at(entry.first) = realize->block; + break; + } + } + return std::move(realize); + } + Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, + /*input_iters=*/loop_var2extent_, + /*input_pred=*/op->predicate, + /*require_bijective=*/false); + if (v.same_as(op->iter_values)) { + return GetRef(op); + } else { + ObjectPtr n = CopyOnWrite(op); + n->iter_values = std::move(v); + return Stmt(n); + } + } + + /*! \brief The reuse mapping */ + MapNode* opaque_blocks_; + /*! \brief The range of loops */ + Map loop_var2extent_; +}; + +class HasAnnotationOrThreadBindingError : public ScheduleError { + public: + explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop) + : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive can't be applied because the loop has annotation or " + "thread binding"; + } + + String DetailRenderTemplate() const final { + return "The primitive can't be applied because the loop {0} has annotation or thread binding"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class OuterNotInnerParent : public ScheduleError { + public: + explicit OuterNotInnerParent(IRModule mod, For outer, For inner) + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} + + String FastErrorString() const final { + return "ScheduleError: The outer loop is not the parent of the inner loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the outer loop {0} is not the parent of the inner " + "loop {1}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class NotOnlyChildError : public ScheduleError { + public: + explicit NotOnlyChildError(IRModule mod, For outer, For inner) + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} + + String FastErrorString() const final { + return "ScheduleError: The inner loop is not the only child of outer loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the inner loop {1} is not the only child of outer " + "loop {0}."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class LoopNotStartWithZeroError : public ScheduleError { + public: + explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive only supports loop starting with 0"; + } + + String DetailRenderTemplate() const final { + return "The loop {0} does not start with 0, which is not supported"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class NotSingleInferFactorError : public ScheduleError { + public: + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; +}; + +class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors) { + // Invariance + // - The total repeat number has not changed for each direct child block with updating predicate. + // - The execution order has not changed. (The block executes with the same args and the same + // order with before. + // Step 1. Check correctness + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + } + // Currently, loops not starting with 0 are not supported + arith::Analyzer analyzer; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + // Step 2. Replace all occurrences of the original loop var with new variables + int n = factors.size(); + PrimExpr substitute_value = 0; + std::vector new_loop_vars; + new_loop_vars.reserve(n); + for (int i = 0; i < n; i++) { + const PrimExpr& factor = factors[i]; + Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); + substitute_value = substitute_value * factor + var; + analyzer.Bind(var, Range::FromMinExtent(0, factor)); + new_loop_vars.emplace_back(std::move(var)); + } + Map opaque_block_reuse; + Stmt new_stmt = loop->body; + new_stmt = SubstituteVarAndCollectOpaqueBlock( + [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }, + &opaque_block_reuse)(std::move(new_stmt)); + // Step 3. Update predicate to guard the loop + PrimExpr predicate = substitute_value < loop->extent; + if (!analyzer.CanProve(predicate)) { + new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); + } + // Step 4. Generate nested loops to replace the original loop and simplify the binding + for (int i = n - 1; i >= 0; i--) { + new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt); + } + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref), + opaque_block_reuse.CopyOnWrite()); + self->Replace(loop_sref, new_stmt, opaque_block_reuse); + Array result_srefs; + result_srefs.reserve(n); + for (int i = 0; i < n; i++) { + result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); + const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode); + new_stmt = outer_loop->body; + } + return result_srefs; +} + +StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { + // Invariance + // - The total repeat number has not changed for each direct child block. + // - The execution order has not changed. (The block executes with the same + // args and the same order with before.) + std::vector loops; + loops.reserve(loop_srefs.size()); + StmtSRef outer_loop_sref{nullptr}; + const ForNode* outer_loop = nullptr; + arith::Analyzer analyzer; + // Step 1. check correctness + for (const StmtSRef& sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + } + if (outer_loop_sref.defined()) { + if (sref->parent != outer_loop_sref.get()) { + throw OuterNotInnerParent(self->mod, GetRef(outer_loop), GetRef(loop)); + } + if (!outer_loop->body.same_as(GetRef(loop))) { + throw NotOnlyChildError(self->mod, GetRef(outer_loop), GetRef(loop)); + } + } + outer_loop_sref = sref; + outer_loop = loop; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + loops.push_back(loop); + } + // Step 2. Create fused loop var and replace the original loop vars + std::string suffix; + int n = loops.size(); + for (int i = 1; i < n; i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Array substitute_value; + substitute_value.resize(loops.size()); + PrimExpr tot = fused_var; + for (int i = static_cast(loops.size()) - 1; i >= 0; i--) { + substitute_value.Set(i, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + Stmt new_stmt = loops.back()->body; + Map opaque_block_reuse; + auto f_substitute = [&](const Var& v) -> Optional { + for (int i = 0; i < n; i++) { + if (v.same_as(loops[i]->loop_var)) { + return substitute_value[i]; + } + } + return NullOpt; + }; + new_stmt = + SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt)); + // Step 3. Generate a loop to replace the original loops + PrimExpr fused_extent = 1; + for (int i = 0; i < n; i++) { + fused_extent *= loops[i]->extent; + } + fused_extent = analyzer.Simplify(fused_extent); + new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings( + std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite()); + self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); + return self->stmt2ref.at(new_stmt.get()); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 115f7936f64e..77d17c9dc6e9 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -123,6 +123,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); /******** (FFI) loops manipulation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); /******** (FFI) compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py new file mode 100644 index 000000000000..4c5c49a1a039 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -0,0 +1,453 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k in tir.grid(128, 128, n): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i_j_k_fused in tir.serial(0, (n * 16384)): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128))) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, n)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n)) + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + C = tir.alloc_buffer((128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "C") as [vi, vj, vk]: + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128, annotations={"useless_annotation": True}): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(10, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j, k in tir.grid(128, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i, j, k]]) + tir.writes([B[i, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for fused in tir.serial(0, 2097152): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128)) + tir.bind(vk, tir.floormod(fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, ((i1 * 64) + i3)) + tir.bind(vj, ((j1 * 32) + j2)) + tir.bind(vk, ((k1 * 8) + k2)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i1 * 64 + i3) + tir.bind(vj, j1 * 64 + j3) + tir.bind(vk, k1 * 64 + k3) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_predicate(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i0, i1, i2, j0, j1, k0, k1 in tir.grid(1000, 2, 3, 1, 129, 3, 43): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.where( + ( + ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) + and (((k0 * 43) + k1) < 128) + ) + ) + tir.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) + tir.bind(vj, j1) + tir.bind(vk, ((k0 * 43) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fuse_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i_j_k_fused in tir.serial(0, 2097152): + with tir.block([], "opaque"): + tir.reads( + [ + A[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + tir.writes( + [ + B[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, 128), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + + for i0, i1, j, k in tir.grid(8, 16, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i0 * 16 + i1, j, k]]) + tir.writes([B[i0 * 16 + i1, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i0 * 16 + i1) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + with tir.block([16, 16], "A") as [vi, vj]: + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + with tir.block([16, 16], "B") as [vi, vj]: + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@tvm.script.tir +def opaque_access_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16]) + B = tir.match_buffer(b, [16, 16]) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +@tvm.script.tir +def opaque_access_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + B = tir.match_buffer(b, (16, 16)) + for i, j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i, j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_fuse(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_fused, sch.mod["main"]) + + +def test_split(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[2, 1, 64]) + sch.split(j, factors=[4, 32]) + sch.split(k, factors=[16, 8]) + tvm.ir.assert_structural_equal(elementwise_split_case0, sch.mod["main"]) + + +def test_split_with_inferred_factor(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[None, 1, 64]) + sch.split(j, factors=[2, None, 64]) + sch.split(k, factors=[2, 1, None]) + tvm.ir.assert_structural_equal(elementwise_split_case1, sch.mod["main"]) + + +def test_split_with_predicate(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[1000, 2, 3]) + sch.split(j, factors=[None, 129]) + sch.split(k, factors=[3, None]) + tvm.ir.assert_structural_equal(elementwise_split_with_predicate, sch.mod["main"]) + + +def test_fuse_fail_not_only_child(): + sch = tir.Schedule(elementwise_with_seq, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + + +def test_fuse_split_fail_with_annotation(): + sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_split_fail_not_start_with_zero(): + sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_fuse_with_opaque_block, sch.mod["main"]) + + +def test_fuse_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mode=True) + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.fuse(i, j) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.fuse(i, j) + tvm.ir.assert_structural_equal(opaque_access_fused, sch.mod["main"]) + + +def test_split_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.split(i, factors=[None, 16]) + tvm.ir.assert_structural_equal(elementwise_split_with_opaque_block, sch.mod["main"]) + + +def test_split_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mode=True) + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.split(j, factors=[None, 4]) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.split(j, factors=[None, 4]) + tvm.ir.assert_structural_equal(opaque_access_split, sch.mod["main"]) + + +def test_fuse_split_fail_with_thread_binding(): + sch = tir.Schedule(elementwise_with_thread_binding, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_symbolic_fused, sch.mod["main"]) + + +def test_split_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(k, factors=[10, None]) + tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + + +if __name__ == "__main__": + pytest.main([__file__])