From 2c1bb3d60b6d3cbf3eadd0767b3eb8f27b400b87 Mon Sep 17 00:00:00 2001
From: "yin.changsheng"
Date: Fri, 24 Mar 2023 09:53:05 +0000
Subject: [PATCH] [TIR] Add merge primitive for TIR schedule

---
 include/tvm/tir/schedule/schedule.h           |  10 +
 python/tvm/tir/schedule/schedule.py           |  78 +++++++
 src/tir/schedule/concrete_schedule.cc         |  11 +
 src/tir/schedule/concrete_schedule.h          |   1 +
 src/tir/schedule/primitive.h                  |  13 ++
 .../schedule/primitive/loop_transformation.cc | 192 +++++++++++++++++
 src/tir/schedule/schedule.cc                  |   1 +
 src/tir/schedule/traced_schedule.cc           |  10 +
 src/tir/schedule/traced_schedule.h            |   1 +
 .../unittest/test_tir_schedule_merge.py       | 197 ++++++++++++++++++
 10 files changed, 514 insertions(+)
 create mode 100644 tests/python/unittest/test_tir_schedule_merge.py

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 570560c62d8c1..612e1453874d2 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -292,6 +292,16 @@ class ScheduleNode : public runtime::Object {
    */
  virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
  /******** Schedule: Transform loops ********/
+  /*!
+   * \brief Merge a list of loops into one. The loops under their LCA must satisfy:
+   * 1) Under the same scope.
+   * 2) Can't have annotations or thread bindings.
+   * 3) Start with 0 and have the same domain.
+   * 4) The inner loop must be the only child of the outer loop.
+   * \param loop_rvs The loops to be merged
+   * \return The new loop after merge
+   */
+  virtual LoopRV Merge(const Array<LoopRV>& loop_rvs) = 0;
  /*!
   * \brief Fuse a list of consecutive loops into one. It requires:
   * 1) The loops can't have annotations or thread bindings.
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 68f0b9454cb13..859a5446629b2 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -541,6 +541,84 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
         return list(_ffi_api.ScheduleGetConsumers(self, block))  # type: ignore # pylint: disable=no-member

     ########## Schedule: Transform loops ##########
+    @type_checked
+    def merge(
+        self,
+        *loops: List[LoopRV],
+    ) -> LoopRV:
+        """Merge a list of loops into one. The loops under their LCA must satisfy:
+        1) Under the same scope.
+        2) Can't have annotations or thread bindings.
+        3) Start with 0 and have the same domain.
+        4) The inner loop must be the only child of the outer loop.
+
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be merged
+
+        Returns
+        -------
+        merged_loop : LoopRV
+            The new loop after merge
+
+        Examples
+        --------
+
+        Before applying merge, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                C = T.match_buffer(c, (128, 128))
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do merge:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_merge)
+            i1, _ = sch.get_loops(sch.get_block("B"))
+            i2, _ = sch.get_loops(sch.get_block("C"))
+            sch.merge(i1, i2)
+            print(sch.mod["main"].script())
+
+        After applying merge, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                C = T.match_buffer(c, (128, 128))
+                # the two outer loops are merged into one
+                for i_m in range(128):
+                    for j in range(128):
+                        with T.block("B"):
+                            vi, vj = T.axis.remap("SS", [i_m, j])
+                            T.reads(A[vi, vj])
+                            T.writes(B[vi, vj])
+                            B[vi, vj] = A[vi, vj] * T.float32(2)
+                    for j in range(128):
+                        with T.block("C"):
+                            vi, vj = T.axis.remap("SS", [i_m, j])
+                            T.reads(A[vi, vj])
+                            T.writes(C[vi, vj])
+                            C[vi, vj] = A[vi, vj] * T.float32(2)
+        """
+        return _ffi_api.ScheduleMerge(self, loops)  # type: ignore # pylint: disable=no-member
+
     @type_checked
     def fuse(
         self,
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 93ea38169d74b..5a82325586f02 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -356,6 +356,17 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {

 /******** Schedule: Transform loops ********/

+LoopRV ConcreteScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
+  CHECK(!loop_rvs.empty()) << "ValueError: 'merge' requires at least 1 loop(s)";
+  Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::Merge(state_, loop_srefs);
+  TVM_TIR_SCHEDULE_END("merge", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<LoopRV>(result);
+}
+
 LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
   CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
   Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 227288b232d92..74953f1270b35 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -101,6 +101,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
   /******** Schedule: Transform loops ********/
   LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
+  LoopRV Merge(const Array<LoopRV>& loop_rvs) override;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<PrimExpr>>& factors,
                       bool preserve_unit_iters) override;
   void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 09185498e143d..0222e34558c00 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -161,6 +161,19 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
  */
 TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
                               const Array<PrimExpr>& factors, bool preserve_unit_iters);
+
+/*!
+ * \brief Merge a list of loops into one. The loops under their LCA must satisfy:
+ * 1) Under the same scope.
+ * 2) Can't have annotations or thread bindings.
+ * 3) Start with 0 and have the same domain.
+ * 4) The inner loop must be the only child of the outer loop.
+ * \param self The state of the schedule
+ * \param loop_srefs An array of srefs to the loops to be merged
+ * \return The new loop after merge
+ */
+TVM_DLL StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs);
+
 /*!
  * \brief Fuse a list of consecutive loops into one. It requires:
  * 1) The loops can't have annotations or thread bindings.
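For readers skimming the user-facing API above, here is a minimal usage sketch (not part of the patch itself). It drives the new `merge` primitive from Python on the `before_merge` workload defined in the docstring; the variable names are illustrative only, and the printed module is expected to match the `after_merge` form shown there:

.. code-block:: python

    import tvm
    from tvm import tir

    sch = tir.Schedule(before_merge)
    # Outer loops of blocks "B" and "C"; both iterate over the same domain [0, 128).
    i1, _ = sch.get_loops(sch.get_block("B"))
    i2, _ = sch.get_loops(sch.get_block("C"))
    # `merge` returns the newly created merged loop, which can be fed to later primitives.
    merged = sch.merge(i1, i2)
    print(sch.mod["main"].script())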
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index 992817e87e2da..af8ade7847795 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -451,6 +451,165 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
   return result_srefs;
 }

+class LoopReconstructor : private StmtMutator {
+ public:
+  explicit LoopReconstructor(Block scope_root, std::vector<std::vector<const ForNode*>> loops)
+      : scope_root_(scope_root), loops_(loops) {}
+
+  using StmtMutator::operator();
+
+  /*!
+   * \brief Create the new nested loops induced by the given loops
+   */
+  void MakeNewLoop() {
+    Array<Var> new_loop_vars;
+    Array<PrimExpr> new_loop_extents;
+    Array<Stmt> new_stmts;
+    for (size_t i = 0; i < loops_.size(); i++) {
+      Map<Var, PrimExpr> var_map;
+      for (size_t j = 0; j < loops_[i].size(); j++) {
+        if (i == 0) {
+          std::string suffix = loops_[i][j]->loop_var->name_hint;
+          suffix += "_m";
+          int bits = loops_[i][j]->loop_var.dtype().bits();
+          Var merged_var(suffix, DataType::Int(bits));
+          new_loop_vars.push_back(merged_var);
+          new_loop_extents.push_back(loops_[i][j]->extent);
+        }
+        var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
+      }
+      auto new_stmt = Substitute(loops_[i][0]->body, var_map);
+      new_stmts.push_back(new_stmt);
+      this->need_remove_loop_.push_back(loops_[i].back());
+    }
+    auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0], ForKind::kSerial,
+                        SeqStmt(std::move(new_stmts)));
+    this->new_inner_loop_ = new_loop;
+    for (size_t i = 1; i < new_loop_vars.size(); ++i) {
+      const Var& loop_var = new_loop_vars[i];
+      const PrimExpr& loop_extent = new_loop_extents[i];
+      new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial, new_loop);
+    }
+    this->new_outer_loop_ = new_loop;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block != scope_root_.get()) {
+      return GetRef<Stmt>(block);
+    }
+    return StmtMutator::VisitStmt_(block);
+  }
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    if (loop == need_remove_loop_.back()) {
+      return new_outer_loop_;
+    } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(), loop)) {
+      return Evaluate(0);
+    }
+    return StmtMutator::VisitStmt_(loop);
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
+    Array<Stmt> filtered;
+    for (Stmt stmt : ret->seq) {
+      if (!is_no_op(stmt)) {
+        filtered.push_back(std::move(stmt));
+      }
+    }
+    ret = SeqStmt(filtered);
+    if (ret->size() == 0) {
+      return Evaluate(0);
+    } else if (ret->size() == 1) {
+      return ret->seq[0];
+    } else {
+      return std::move(ret);
+    }
+  }
+
+ public:
+  /*! \brief The root block of the block scope */
+  Block scope_root_;
+  /*! \brief The given loops to be merged */
+  std::vector<std::vector<const ForNode*>> loops_;
+  /*! \brief The outermost new loop to replace the original loop */
+  For new_outer_loop_{nullptr};
+  /*! \brief The innermost new loop to replace the original loop */
+  For new_inner_loop_{nullptr};
+  /*! \brief The loops to be removed */
+  std::vector<const ForNode*> need_remove_loop_;
+};
+
+StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
+  // Invariance
+  // - The total repeat number has not changed for each direct child block.
+  // - The execution order has not changed. (The block executes with the same
+  //   args and the same order with before.)
+  arith::Analyzer analyzer;
+  StmtSRef scope_root_sref;
+  StmtSRef lca = GetSRefLowestCommonAncestor(loop_srefs);
+  std::vector<std::vector<const ForNode*>> lca_nest_loops;
+  // Step 1. Check that the loops to be merged are valid.
+  std::vector<const ForNode*> nest_loop_loops;
+  std::vector<PrimExpr> nest_loop_extents;
+  for (size_t i = 0; i < loop_srefs.size(); i++) {
+    const StmtSRef& sref = loop_srefs[i];
+    auto scope_root_sref_ = GetScopeRoot(self, sref, /*require_stage_pipeline=*/false);
+    std::vector<PrimExpr> nest_loop_i_extents;
+    std::vector<const ForNode*> nest_loop_i_loops;
+    for (auto p = sref.get(); p != lca.get(); p = p->parent) {
+      if (auto loop = p->StmtAs<ForNode>()) {
+        if (!loop->annotations.empty() || loop->thread_binding.defined()) {
+          throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
+        }
+        CheckLoopStartsWithZero(self, GetRef<StmtSRef>(p), &analyzer);
+        nest_loop_i_loops.push_back(loop);
+        nest_loop_i_extents.push_back(loop->extent);
+      }
+    }
+    lca_nest_loops.push_back(nest_loop_i_loops);
+    const ForNode* outer_loop = nullptr;
+    for (auto iter = nest_loop_i_loops.rbegin(); iter != nest_loop_i_loops.rend(); ++iter) {
+      if (outer_loop && !outer_loop->body.same_as(GetRef<For>(*iter))) {
+        throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), GetRef<For>(*iter));
+      }
+      outer_loop = *iter;
+    }
+    if (i == 0) {
+      scope_root_sref = scope_root_sref_;
+      nest_loop_loops = nest_loop_i_loops;
+      nest_loop_extents = nest_loop_i_extents;
+    } else {
+      if (scope_root_sref_.get() != scope_root_sref.get()) {
+        LOG(FATAL) << "ScheduleError: Expected the loops to be under the same block scope";
+        throw;
+      }
+      if (nest_loop_i_extents.size() != nest_loop_extents.size()) {
+        LOG(FATAL) << "ScheduleError: Expected the loops to be merged to have the same nesting "
+                      "depth under their LCA";
+        throw;
+      } else {
+        for (size_t j = 0; j < nest_loop_i_extents.size(); j++) {
+          if (!analyzer.CanProveEqual(nest_loop_i_extents[j], nest_loop_extents[j])) {
+            LOG(FATAL) << "ScheduleError: Expected the loops to be merged to have equal extents, "
+                       << "but at nesting level " << j << " the extents are "
+                       << nest_loop_extents[j] << " and " << nest_loop_i_extents[j];
+            throw;
+          }
+        }
+      }
+    }
+  }
+  // Step 2. Create merged loops and replace the original loops
+  Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
+  LoopReconstructor reconstructor(scope_root, lca_nest_loops);
+  reconstructor.MakeNewLoop();
+  Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
+  // Step 3. Do the actual replacement
+  self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
+  return self->stmt2ref.at(reconstructor.new_inner_loop_.get());
+}
+
 StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
   // Invariance
   // - The total repeat number has not changed for each direct child block.
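As an aside (not part of the patch), the Step 1 validation above can be restated at a high level: for every loop passed to `merge`, the primitive walks from that loop up to the LCA, collects the loop nest, and then requires all the collected nests to have the same depth, pairwise provably-equal extents, and zero starts. A rough Python restatement over plain integers, using the hypothetical helper name `check_mergeable`, looks like this:

.. code-block:: python

    def check_mergeable(nests):
        # `nests` is a list of loop nests; each nest is a list of (min, extent) pairs
        # ordered from the loop given to `merge` outwards to (but excluding) the LCA.
        depth = len(nests[0])
        for nest in nests:
            if len(nest) != depth:
                raise ValueError("merged loops must have the same nesting depth under their LCA")
            for (lo, extent), (_, ref_extent) in zip(nest, nests[0]):
                if lo != 0:
                    raise ValueError("merged loops must start at 0")
                if extent != ref_extent:
                    # the C++ implementation uses arith::Analyzer::CanProveEqual here
                    raise ValueError("merged loops must have equal extents")

    # e.g. the nests checked by test_merge2 below: [(0, 8), (0, 8)] for both "C" and "D"
    check_mergeable([[(0, 8), (0, 8)], [(0, 8), (0, 8)]])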
@@ -795,6 +954,38 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };

+struct MergeTraits : public UnpackedInstTraits<MergeTraits> {
+  static constexpr const char* kName = "Merge";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  template <size_t delta>
+  static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
+                                           const Array<ObjectRef>& inputs) {
+    setter(delta, inputs);
+  }
+
+  static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
+    return sch->Merge(loop_rvs);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
+    PythonAPICall py("merge");
+    for (const String& loop_rv : loop_rvs) {
+      py.Input("", loop_rv);
+    }
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
   static constexpr const char* kName = "Fuse";
   static constexpr bool kIsPure = false;
@@ -893,6 +1084,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
 };

 TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
+TVM_REGISTER_INST_KIND_TRAITS(MergeTraits);
 TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
 TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
 TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index a0e39b74d31b9..2958df1e8e91d 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -153,6 +153,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
     .set_body_method<Schedule>(&ScheduleNode::GetConsumers);
 /******** (FFI) Transform loops ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method<Schedule>(&ScheduleNode::Merge);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 2b6a7f71d4f5c..786dcc285798b 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -176,6 +176,16 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {

 /******** Schedule: Transform loops ********/

+LoopRV TracedScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
+  LoopRV result = ConcreteScheduleNode::Merge(loop_rvs);
+  static const InstructionKind& kind = InstructionKind::Get("Merge");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
 LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) {
   LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops);
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 8b9621c749de2..688690e872ce8 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -61,6 +61,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
   /******** Schedule: Transform loops ********/
   LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
+  LoopRV Merge(const Array<LoopRV>& loop_rvs) final;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<PrimExpr>>& factor_rvs,
                       bool preserve_unit_iters) final;
void Reorder(const Array& ordered_loop_rvs) final; diff --git a/tests/python/unittest/test_tir_schedule_merge.py b/tests/python/unittest/test_tir_schedule_merge.py new file mode 100644 index 0000000000000..9970808fdf2ab --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_merge.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import te, tir +from tvm.script import tir as T +from tvm.tir.expr import IntImm +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def elementwise(a: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + D = T.match_buffer(d, (64, 64)) + B = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): + with T.block("C"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + for i_0, j_0, i_1, j_1 in T.grid(8, 8, 8, 8): + with T.block("D"): + vi = T.axis.spatial(64, i_0 * 8 + i_1) + vj = T.axis.spatial(64, j_0 * 8 + j_1) + T.reads(B[vi, vj]) + T.writes(D[vi, vj]) + D[vi, vj] = B[vi, vj] + T.float32(2) + + +@T.prim_func +def elementwise_different_loops_extent(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.match_buffer(c, (128, 128, 128)) + B = T.alloc_buffer((128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(1, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@T.prim_func +def elementwise_with_seq(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.match_buffer(c, (128, 128, 128)) + B = T.alloc_buffer((128, 128, 128)) + D = T.alloc_buffer((128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("D"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + D[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in T.serial(0, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = B[vi, vj, vk] * 2.0 + + 
+@T.prim_func +def elementwise_merged(a: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + D = T.match_buffer(d, (64, 64)) + B = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i_0_m in range(8): + for j_0, i_1, j_1 in T.grid(8, 16, 16): + with T.block("C"): + vi = T.axis.spatial(128, i_0_m * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + for j_0, i_1, j_1 in T.grid(8, 8, 8): + with T.block("D"): + vi = T.axis.spatial(64, i_0_m * 8 + i_1) + vj = T.axis.spatial(64, j_0 * 8 + j_1) + T.reads(B[vi, vj]) + T.writes(D[vi, vj]) + D[vi, vj] = B[vi, vj] + T.float32(2) + + +@T.prim_func +def elementwise_merged2(a: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + D = T.match_buffer(d, (64, 64)) + B = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i_0_m, j_0_m in T.grid(8, 8): + for i_1, j_1 in T.grid(16, 16): + with T.block("C"): + vi = T.axis.spatial(128, i_0_m * 16 + i_1) + vj = T.axis.spatial(128, j_0_m * 16 + j_1) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + for i_1, j_1 in T.grid(8, 8): + with T.block("D"): + vi = T.axis.spatial(64, i_0_m * 8 + i_1) + vj = T.axis.spatial(64, j_0_m * 8 + j_1) + T.reads(B[vi, vj]) + T.writes(D[vi, vj]) + D[vi, vj] = B[vi, vj] + T.float32(2) + + +def test_merge(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_c = sch.get_block("C") + block_d = sch.get_block("D") + i = sch.get_loops(block_c)[0] + j = sch.get_loops(block_d)[0] + sch.merge(i, j) + tvm.ir.assert_structural_equal(elementwise_merged, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_merge2(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_c = sch.get_block("C") + block_d = sch.get_block("D") + i = sch.get_loops(block_c)[1] + j = sch.get_loops(block_d)[1] + sch.merge(i, j) + tvm.ir.assert_structural_equal(elementwise_merged2, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_merge_fail_not_only_child(): + sch = tir.Schedule(elementwise_with_seq, debug_mask="all") + block_b = sch.get_block("B") + _, _, b = sch.get_loops(block_b) + block_c = sch.get_block("C") + _, _, c = sch.get_loops(block_c) + with pytest.raises(tvm.tir.ScheduleError): + sch.merge(b, c) + + +def test_merge_fail_with_dependent_loops(): + sch = tir.Schedule(elementwise_different_loops_extent, debug_mask="all") + block_b = sch.get_block("B") + _, _, b = sch.get_loops(block_b) + block_c = sch.get_block("C") + _, _, c = sch.get_loops(block_c) + with pytest.raises(tvm.tir.ScheduleError): + sch.merge(b, c) + + +if __name__ == "__main__": + tvm.testing.main()
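A short illustrative snippet (not part of the test suite) for inspecting how `merge` is recorded by the traced schedule, which is what `verify_trace_roundtrip` exercises above; the exact random-variable names in the printed trace depend on the schedule state:

.. code-block:: python

    sch = tir.Schedule(elementwise, debug_mask="all")
    i = sch.get_loops(sch.get_block("C"))[0]
    j = sch.get_loops(sch.get_block("D"))[0]
    sch.merge(i, j)
    # MergeTraits::UnpackedAsPython renders the instruction as a `sch.merge(...)` call,
    # so the printed trace should contain a line of the form `lN = sch.merge(..., ...)`.
    print(sch.trace)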