From ca30e5e2e4f89aa4cce318da7c333fa84a964d26 Mon Sep 17 00:00:00 2001
From: wrongtest
Date: Thu, 28 Jul 2022 01:35:54 +0800
Subject: [PATCH] TIR Schedule primitive - decompose_padding (#12174)

Co-authored-by: baoxinqi
---
 include/tvm/tir/schedule/schedule.h           |   9 +
 python/tvm/tir/schedule/schedule.py           |  78 +++
 src/tir/schedule/concrete_schedule.cc         |   9 +
 src/tir/schedule/concrete_schedule.h          |   2 +
 src/tir/schedule/error.h                      |  32 +-
 src/tir/schedule/primitive.h                  |  11 +
 .../schedule/primitive/decompose_padding.cc   | 574 ++++++++++++++++++
 src/tir/schedule/primitive/reduction.cc       |  27 +-
 src/tir/schedule/schedule.cc                  |   5 +
 src/tir/schedule/traced_schedule.cc           |  12 +
 src/tir/schedule/traced_schedule.h            |   2 +
 .../test_tir_schedule_decompose_padding.py    | 313 ++++++++++
 12 files changed, 1048 insertions(+), 26 deletions(-)
 create mode 100644 src/tir/schedule/primitive/decompose_padding.cc
 create mode 100644 tests/python/unittest/test_tir_schedule_decompose_padding.py

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 8e160c61328c..39de9de528ab 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -617,6 +617,15 @@ class ScheduleNode : public runtime::Object {
                                 BufferIndexType buffer_index_type,
                                 const Array<IntImm>& axis_separators) = 0;
 
+  /*!
+   * \brief Decompose a padding block into a block filling const pad values and a block
+   * writing in-bound values.
+   * \param block_rv The block that matches the padding pattern.
+   * \param loop_rv The loop above which the const filling block is inserted.
+   * \return The const pad value filling block.
+   */
+  virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
+
   /******** Schedule: Misc ********/
   /*! \brief A no-op that marks the start of postprocessing phase of scheduling */
   virtual void EnterPostproc() = 0;
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 73bb8140e17d..7bec054b7368 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2639,6 +2639,84 @@ def after_set_axis_separators(
             self, block, buffer_index, buffer_index_type_enum, axis_separators
         )
 
+    ########## Schedule: Padding decomposition #########
+    @type_checked
+    def decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV:
+        """Decompose a block of padding computation pattern into two separate blocks.
+
+        a) The block which fills const pad values into the full write region;
+
+        b) The block which fills in-bound values into the region where the pad predicate is true.
+
+        The pad value filling block is inserted right before the given loop.
+
+        The schedule primitive requires:
+
+        1) The input block is a complete block.
+
+        2) The input loop is an ancestor of the block.
+
+        3) The input block matches the padding pattern.
+
+        Parameters
+        ----------
+        block : Union[BlockRV, str]
+            The padding block to be decomposed.
+        loop : LoopRV
+            The loop above which the pad value filling block is inserted.
+
+        Returns
+        -------
+        pad_value_block : BlockRV
+            The block filling const pad values.
+
+        Examples
+        --------
+        Before decompose-padding, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
+                for i in range(140):
+                    with T.block("block"):
+                        vi = T.axis.remap("S", [i])
+                        y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32")
+
+        Create the schedule and do decompose-padding with specified loop:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_decompose, debug_mask="all")
+            block = sch.get_block("block")
+            sch.decompose_padding(block, sch.get_loops(block)[0])
+            print(sch.mod["main"].script())
+
+        After applying decompose-padding, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
+                for i in T.serial(140):
+                    with T.block("block_pad_const"):
+                        vi = T.axis.spatial(140, i)
+                        y[vi] = 0
+                for i in T.serial(128):
+                    with T.block("block"):
+                        vi = T.axis.spatial(128, i)
+                        y[vi + 6] = x[vi]
+        """
+        block = self._normalize_block_arg(block)
+        return _ffi_api.ScheduleDecomposePadding(  # type: ignore # pylint: disable=no-member
+            self, block, loop
+        )
+
+    @type_checked
+    def can_decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> bool:
+        """Check whether the block matches the padding pattern and can be decomposed."""
+        return _ffi_api.CanDecomposePadding(self, block, loop)  # type: ignore # pylint: disable=no-member
+
     ########## Schedule: Misc ##########
 
     @type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 35f31ac9165c..9d0bb885e2c3 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -764,6 +764,15 @@ void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_
   this->state_->DebugVerify();
 }
 
+BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::DecomposePadding(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
+  TVM_TIR_SCHEDULE_END("decompose-padding", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
 /******** Schedule: Misc ********/
 
 }  // namespace tir
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index feea310bd7af..d4fa522492a5 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -147,6 +147,8 @@ class ConcreteScheduleNode : public ScheduleNode {
   void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                         BufferIndexType buffer_index_type,
                         const Array<IntImm>& axis_separators) override;
+  /******** Schedule: Padding decomposition ********/
+  BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override;
   /******** Schedule: Misc ********/
   void EnterPostproc() override {}
diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h
index 46447cfbde49..e28164c6c39b 100644
--- a/src/tir/schedule/error.h
+++ b/src/tir/schedule/error.h
@@ -18,9 +18,11 @@
  */
 #ifndef TVM_TIR_SCHEDULE_ERROR_H_
 #define TVM_TIR_SCHEDULE_ERROR_H_
-
 #include <tvm/tir/schedule/state.h>
+
+#include <string>
+#include <utility>
+
 namespace tvm {
 namespace tir {
 
@@ -52,6 +54,34 @@ class ScheduleError : public tvm::runtime::Error {
   String RenderReport(const String& primitive) const;
 };
 
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block, const std::string& primitive)
+      : mod_(std::move(mod)),
+        loop_(std::move(loop)),
+        block_(std::move(block)),
+        primitive_(primitive) {}
+
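+  // `primitive_` records which schedule primitive raised the error (for example
+  // "decompose_reduction" or "decompose_padding"), so the error messages below can be
+  // shared by every primitive that requires the loop to be an ancestor of the block.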
+  String FastErrorString() const final {
+    return "ScheduleError: " + primitive_ + " expects the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of " << primitive_
+       << " is required to be an ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+  std::string primitive_;
+};
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 608368fbb31f..0e9f322356a2 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -482,6 +482,17 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
 TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
                                   const IndexMap& index_map);
 
+/******** Schedule: Padding decomposition ********/
+/*!
+ * \brief Decompose a padding block into a block filling const pad values and a block
+ * writing in-bound values.
+ * \param block_sref The block sref that matches the padding pattern.
+ * \param loop_sref The loop above which the const filling block is inserted.
+ * \return The padding value filling block sref.
+ */
+TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
+                                  const StmtSRef& loop_sref);
+
 /******** Schedule: Misc ********/
 
 }  // namespace tir
diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc
new file mode 100644
index 000000000000..365c6d43f127
--- /dev/null
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -0,0 +1,574 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../../transforms/ir_utils.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Information used to create new padding block */
+struct PaddingBlockInfo {
+  /*! \brief In-bound block iter regions, wrt loop vars. */
+  Array<Range> in_bound_region;
+  /*! \brief In-bound value, wrt block iter vars. */
+  PrimExpr in_bound_value;
+  /*! \brief Condition of in-bound write, wrt loop vars. */
+  PrimExpr in_bound_predicate;
+  /*! \brief Padding value, should be a constant. */
+  PrimExpr pad_value;
+};
+
+class PaddingPatternMatchError : public ScheduleError {
+ public:
+  PaddingPatternMatchError(IRModule mod, Block block, const std::string& error_msg)
+      : mod_(std::move(mod)), block_(std::move(block)), error_msg_(error_msg) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_padding expects the block to match the padding pattern\n  " +
+           error_msg_;
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_padding expects the block {0} to match the padding pattern\n  "
+       << error_msg_;
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  IRModule mod_;
+  Block block_;
+  std::string error_msg_;
+};
+
+/*!
+ * \brief Helper class to analyze and check the padding pattern of the block,
+ * then return the padding information.
+ */
+class PaddingInfoAnalyzer {
+ public:
+  static PaddingBlockInfo CheckAndGetPaddingInfo(IRModule mod, const BlockRealizeNode* realize,
+                                                 const Map<Var, Range>& dom_map,
+                                                 arith::Analyzer* analyzer) {
+    PaddingInfoAnalyzer padding_analyzer(analyzer);
+    if (!padding_analyzer.MatchPadding(realize, dom_map)) {
+      throw PaddingPatternMatchError(mod, realize->block, padding_analyzer.error_msg_);
+    }
+    return padding_analyzer.info_;
+  }
+
+ private:
+  explicit PaddingInfoAnalyzer(arith::Analyzer* analyzer) : analyzer_(analyzer) {}
+
+  /*! \brief Detect padding pattern and update result. */
+  bool MatchPadding(const BlockRealizeNode* realize, const Map<Var, Range>& dom_map) {
+    // Step 1. Check the block matches the padding computation pattern:
+    //     A[...] = T.if_then_else(predicate, B[...], imm)
+    Block block = realize->block;
+    std::unordered_map<const VarNode*, PrimExpr> iter_values;
+    for (size_t i = 0; i < realize->iter_values.size(); ++i) {
+      Var block_var = block->iter_vars[i]->var;
+      iter_values[block_var.get()] = realize->iter_values[i];
+    }
+    const BufferStoreNode* store = block->body.as<BufferStoreNode>();
+    if (!store) {
+      SetError("Block body is expected to be a BufferStore to the write buffer");
+      return false;
+    }
+    const CallNode* if_then_else = store->value.as<CallNode>();
+    if (!if_then_else || !if_then_else->op.same_as(tir::builtin::if_then_else())) {
+      SetError("Value of the BufferStore is expected to be constrained by a padding predicate");
+      return false;
+    }
+    PrimExpr pad_predicate = Substitute(if_then_else->args[0], iter_values);
+    PrimExpr in_bound_value = if_then_else->args[1];
+    PrimExpr pad_value = if_then_else->args[2];
+    if (!is_const_number(pad_value)) {
+      SetError("Pad value should be constant");
+      return false;
+    }
+
+    // Step 2. Check the in-bound computation is free of side effects.
+    if (SideEffect(if_then_else->args[1]) > CallEffectKind::kReadState) {
+      SetError("In-bound computation should not have side effects");
+      return false;
+    }
+
+    // Step 3. Analyze in-bound write region.
+    PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
+    Array<Range> in_bound_region = this->EstimateInBoundRegion(
+        /*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
+        /*in_bound_predicate=*/in_bound_predicate);
+    if (in_bound_region.empty()) {
+      return false;
+    }
+
+    // Step 4. Update result information.
+    info_.in_bound_value = if_then_else->args[1];
+    info_.in_bound_region = in_bound_region;
+    info_.in_bound_predicate = in_bound_predicate;
+    info_.pad_value = pad_value;
+    return true;
+  }
+
+  /*! \brief Rewrite predicate to left recursive conjunction, drop likely annotation. */
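+  // For example, `likely(i < 7) && (j < 7 && k < 3)` is rewritten into the left
+  // recursive form `(((true && (i < 7)) && (j < 7)) && (k < 3))` before simplification.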
+  PrimExpr RewritePredicate(const PrimExpr& predicate) {
+    PrimExpr res = const_true();
+    std::function<void(PrimExpr)> update = [&res, &update](PrimExpr e) {
+      arith::PVar<PrimExpr> a, b;
+      if ((a && b).Match(e)) {
+        update(a.Eval());
+        update(b.Eval());
+      } else {
+        if (const CallNode* call = e.as<CallNode>()) {
+          if (call->op.same_as(builtin::likely())) {
+            e = call->args[0];
+          }
+        }
+        res = res && e;
+      }
+    };
+    update(predicate);
+    return analyzer_->Simplify(res);
+  }
+
+  /*! \brief Return iteration region of block vars where the padding predicate evals to true. */
+  Array<Range> EstimateInBoundRegion(const Array<PrimExpr>& iter_values,
+                                     const Map<Var, Range>& dom_map,
+                                     const PrimExpr& in_bound_predicate) {
+    Array<Range> region;
+
+    auto res = arith::DetectIterMap(iter_values, dom_map, in_bound_predicate,
+                                    arith::IterMapLevel::Surjective, analyzer_);
+    if (res->indices.empty()) {
+      SetError("Block iters are not independent wrt padding condition");
+      return {};
+    }
+    for (const arith::IterSumExpr& sum : res->indices) {
+      if (sum->args.empty()) {
+        region.push_back(Range::FromMinExtent(sum->base, 1));
+      } else {
+        ICHECK_EQ(sum->args.size(), 1U);
+        if (!analyzer_->CanProveEqual(sum->args[0]->scale, 1)) {
+          SetError("Strided iteration is not supported");
+          return {};
+        }
+        region.push_back(Range::FromMinExtent(sum->base, sum->args[0]->extent));
+      }
+    }
+    return region;
+  }
+
+  void SetError(const std::string& msg) { error_msg_ = msg; }
+
+  /*! \brief padding info analyse result. */
+  PaddingBlockInfo info_;
+  /*! \brief current error message. */
+  std::string error_msg_;
+  /*! \brief arithmetic analyzer. */
+  arith::Analyzer* analyzer_;
+};
+
+/*! \brief Create block to fill constant pad values into full region */
+static std::pair<Stmt, BlockRealize> CreateConstBlock(const BlockRealizeNode* realize,
+                                                      const PaddingBlockInfo& info,
+                                                      const Array<For>& loops,
+                                                      const Stmt& highest_pos_inclusive,
+                                                      arith::Analyzer* analyzer) {
+  const Block& block = realize->block;
+  Array<IterVar> new_iter_vars;
+  Map<Var, Var> repl_dict;
+
+  // create new block itervars
+  for (size_t i = 0; i < block->iter_vars.size(); ++i) {
+    const IterVar& origin_iter = block->iter_vars[i];
+    Var new_var = origin_iter->var.copy_with_suffix("");
+    new_iter_vars.push_back(IterVar(origin_iter->dom, new_var, IterVarType::kDataPar));
+    repl_dict.Set(origin_iter->var, new_var);
+  }
+
+  // rewrite expr helper
+  auto rewrite_expr = [&repl_dict, analyzer](const PrimExpr& e) {
+    return analyzer->Simplify(Substitute(e, repl_dict));
+  };
+
+  // create new write region
+  ICHECK_EQ(block->writes.size(), 1U);
+  BufferRegion write_region =
+      BufferRegion(block->writes[0]->buffer,
+                   MutateArray(block->writes[0]->region, [rewrite_expr](const Range& r) {
+                     return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent));
+                   }));
+
+  // create block to fill const pad values
+  BufferStore store = Downcast<BufferStore>(block->body);
+  store.CopyOnWrite()->value = info.pad_value;
+  store.CopyOnWrite()->indices = MutateArray(store->indices, rewrite_expr);
+  Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/{}, /*writes=*/{write_region},
+                  /*name_hint=*/block->name_hint + "_pad_const", /*body=*/std::move(store));
+
+  // create new loop vars
+  Array<Var> new_loop_vars;
+  for (const For& loop : loops) {
+    Var new_var = loop->loop_var.copy_with_suffix("");
+    new_loop_vars.push_back(new_var);
+    repl_dict.Set(loop->loop_var, new_var);
+    if (loop.same_as(highest_pos_inclusive)) {
+      break;
+    }
+  }
+
+  // create new block realize node
+  Array<PrimExpr> new_iter_values;
+  for (size_t i = 0; i < realize->iter_values.size(); ++i) {
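+    // Rebind each original iter binding to the new loop vars: `repl_dict` maps both the
+    // old block iter vars and the old loop vars to their freshly created copies.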
+    new_iter_values.push_back(rewrite_expr(realize->iter_values[i]));
+  }
+  BlockRealize new_realize(/*iter_values=*/new_iter_values,
+                           /*predicate=*/rewrite_expr(realize->predicate),
+                           /*block=*/new_block);
+
+  // create new loops
+  Stmt nest_stmt_root = new_realize;
+  for (size_t i = 0; i < new_loop_vars.size(); ++i) {
+    For loop = loops[i];
+    nest_stmt_root =
+        For(new_loop_vars[i], loop->min, loop->extent, ForKind::kSerial, nest_stmt_root);
+  }
+
+  return {nest_stmt_root, new_realize};
+}
+
+/*! \brief Create block to fill in-bound region values. */
+static std::pair<Stmt, BlockRealize> CreateInBoundBlock(const BlockRealizeNode* realize,
+                                                        const PaddingBlockInfo& info,
+                                                        const Array<For>& loops,
+                                                        const Stmt& highest_pos_inclusive,
+                                                        arith::Analyzer* analyzer) {
+  const Block& block = realize->block;
+  Array<IterVar> new_iter_vars;
+  Map<Var, Var> repl_dict;
+
+  // record loop ranges to be mutated
+  Map<Var, Range> new_loop_ranges;
+  for (const For& loop : loops) {
+    new_loop_ranges.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+    if (loop.same_as(highest_pos_inclusive)) {
+      break;
+    }
+  }
+
+  // create new block iter vars and iter bindings
+  Array<PrimExpr> new_iter_binding;
+  for (size_t i = 0; i < info.in_bound_region.size(); ++i) {
+    // add new block itervar
+    const IterVar& origin_itervar = block->iter_vars[i];
+    Var new_var = origin_itervar->var.copy_with_suffix("");
+    Range new_range =
+        Range::FromMinExtent(make_const(new_var->dtype, 0), info.in_bound_region[i]->extent);
+    new_iter_vars.push_back(IterVar(new_range, new_var, IterVarType::kDataPar));
+    repl_dict.Set(origin_itervar->var, new_var + info.in_bound_region[i]->min);
+
+    // update new loop range
+    Var loop_var = GetRef<Var>(realize->iter_values[i].as<VarNode>());
+    if (loop_var.defined() && new_loop_ranges.count(loop_var)) {
+      // if the block binding is the loop var with a single child, mutate the loop range
+      // instead of inserting an extra block predicate
+      new_loop_ranges.Set(loop_var, new_range);
+      new_iter_binding.push_back(realize->iter_values[i]);
+      repl_dict.Set(loop_var, loop_var + info.in_bound_region[i]->min);
+      analyzer->Bind(loop_var, new_range, /*allow_override=*/true);
+    } else {
+      new_iter_binding.push_back(
+          analyzer->Simplify(realize->iter_values[i] - info.in_bound_region[i]->min));
+    }
+  }
+
+  // rewrite helpers
+  auto rewrite_expr = [&repl_dict, analyzer](const PrimExpr& e) {
+    return analyzer->Simplify(Substitute(e, repl_dict));
+  };
+  auto rewrite_region = [rewrite_expr](const Region& region) {
+    return MutateArray(region, [rewrite_expr](const Range& r) {
+      return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent));
+    });
+  };
+
+  // create new read/write region for in-bound accesses
+  Array<BufferRegion> reads, writes;
+  for (const BufferRegion& read : block->reads) {
+    reads.push_back(BufferRegion(read->buffer, rewrite_region(read->region)));
+  }
+  for (const BufferRegion& write : block->writes) {
+    writes.push_back(BufferRegion(write->buffer, rewrite_region(write->region)));
+  }
+
+  // create new block realize node
+  BufferStore store = Downcast<BufferStore>(block->body);
+  store.CopyOnWrite()->value = rewrite_expr(info.in_bound_value);
+  store.CopyOnWrite()->indices = MutateArray(store->indices, rewrite_expr);
+  Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/reads, /*writes=*/writes,
+                  /*name_hint=*/block->name_hint, /*body=*/std::move(store));
+  PrimExpr new_predicate = rewrite_expr(info.in_bound_predicate);
+  BlockRealize new_realize(/*iter_values=*/new_iter_binding, /*predicate=*/new_predicate,
+                           /*block=*/new_block);
+
+  // create new loops
+  Stmt nest_stmt_root = new_realize;
+  for (const For& loop : loops) {
+    auto it = new_loop_ranges.find(loop->loop_var);
+    PrimExpr min = it == new_loop_ranges.end() ? loop->min : (*it).second->min;
+    PrimExpr extent = it == new_loop_ranges.end() ? loop->extent : (*it).second->extent;
+    nest_stmt_root = For(loop->loop_var, min, extent, loop->kind, nest_stmt_root,
+                         loop->thread_binding, loop->annotations, loop->span);
+    if (loop.same_as(highest_pos_inclusive)) {
+      break;
+    }
+  }
+  return {nest_stmt_root, new_realize};
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed padding blocks.
+ */
+class DecomposePaddingBlockReplacer : public StmtMutator {
+ public:
+  /*! \brief Replacement information */
+  struct ReplaceDesc {
+    /*! \brief The loop above which to insert the const pad value filling code. */
+    For const_filling_pos;
+    /*! \brief The loop under which to insert the in-bound value filling code. */
+    For in_bound_filling_pos;
+    /*! \brief The const pad value filling loop. */
+    Stmt const_filling_loop;
+    /*! \brief The highest in-bound value filling loop with a single child. */
+    Stmt in_bound_filling_loop;
+    /*! \brief The const pad value filling block. */
+    BlockRealize const_filling_block;
+    /*! \brief The in-bound value filling block. */
+    BlockRealize in_bound_filling_block;
+  };
+
+  static Block Replace(Block scope_root, const ReplaceDesc& desc) {
+    DecomposePaddingBlockReplacer replacer(desc);
+    return Downcast<Block>(replacer(std::move(scope_root)));
+  }
+
+ private:
+  explicit DecomposePaddingBlockReplacer(const ReplaceDesc& desc) : desc_(desc) {}
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    Stmt new_loop;
+    if (op == desc_.in_bound_filling_pos.get()) {
+      // position to rewrite the in-bound filling code
+      new_loop = desc_.in_bound_filling_loop;
+    } else {
+      new_loop = StmtMutator::VisitStmt_(op);
+    }
+    if (op == desc_.const_filling_pos.get()) {
+      // position to insert the pad value filling code
+      return std::move(SeqStmt({desc_.const_filling_loop, new_loop}));
+    }
+    return std::move(new_loop);
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  const ReplaceDesc& desc_;
+};
+
+StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
+                              const StmtSRef& loop_sref, bool check_only) {
+  /*!
+   * Check
+   *   - the block is a complete block
+   *   - the loop is an ancestor of the block
+   *   - the block matches the padding pattern
+   * Mutate
+   *   - generate new block to fill padding values
+   *   - trim original block to write the non-padding part only
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  Map<Var, Range> dom_map;
+  arith::Analyzer analyzer;
+
+  // Check 1. Check the block is complete.
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+  CheckCompleteBlock(self, block_sref, scope_root_sref);
+
+  // Check 2. Check loop_sref is an ancestor of block_sref. Also collect
+  //   - the highest loop position (inclusive) to insert const pad value filling code above.
+  //   - the highest loop position (inclusive) to replace with in-bound value filling code.
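+  // Note that GetLoops returns the loop nest from the outermost loop inwards, so the
+  // reverse iteration below visits the loops innermost-first while binding their ranges.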
+  Array<StmtSRef> loop_srefs = GetLoops(block_sref);
+  Array<For> loops;
+  bool found_const_filling_pos = false;
+  bool found_in_bound_filling_pos = false;
+  For const_filling_pos = GetRef<For>(loop_sref->StmtAs<ForNode>());
+  For in_bound_filling_pos{nullptr};
+  for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) {
+    For cur_loop = GetRef<For>((*it)->StmtAs<ForNode>());
+    Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent);
+    dom_map.Set(cur_loop->loop_var, range);
+    analyzer.Bind(cur_loop->loop_var, range);
+    loops.push_back(cur_loop);
+
+    if (!found_const_filling_pos) {
+      if (cur_loop.same_as(const_filling_pos)) {
+        found_const_filling_pos = true;
+      }
+    }
+
+    if (!found_in_bound_filling_pos) {
+      if (!cur_loop->body->IsInstance<ForNode>() &&
+          !cur_loop->body->IsInstance<BlockRealizeNode>()) {
+        found_in_bound_filling_pos = true;
+      } else {
+        in_bound_filling_pos = cur_loop;
+      }
+    }
+  }
+  ICHECK(in_bound_filling_pos.defined());
+  if (!found_const_filling_pos) {
+    throw LoopPositionError(self->mod, const_filling_pos, GetRef<Block>(block),
+                            "decompose_padding");
+  }
+
+  // Check 3. match padding pattern and return padding operation info.
+  PaddingBlockInfo info =
+      PaddingInfoAnalyzer::CheckAndGetPaddingInfo(self->mod, realize, dom_map, &analyzer);
+
+  // IR Manipulation
+  // Step 1. Create const pad value filling part and in-bound value filling part.
+  DecomposePaddingBlockReplacer::ReplaceDesc replace_desc;
+  replace_desc.const_filling_pos = const_filling_pos;
+  replace_desc.in_bound_filling_pos = in_bound_filling_pos;
+  std::tie(replace_desc.const_filling_loop, replace_desc.const_filling_block) =
+      CreateConstBlock(realize, info, loops, const_filling_pos, &analyzer);
+  std::tie(replace_desc.in_bound_filling_loop, replace_desc.in_bound_filling_block) =
+      CreateInBoundBlock(realize, info, loops, in_bound_filling_pos, &analyzer);
+
+  // Step 2. Execute IR replacement.
+  Block old_scope_root_block = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
+  Block new_scope_root = DecomposePaddingBlockReplacer::Replace(old_scope_root_block, replace_desc);
+  if (check_only) {
+    return block_sref;
+  }
+
+  // Step 3. Update schedule states.
+  self->Replace(scope_root_sref, new_scope_root,
+                {{old_scope_root_block, new_scope_root},
+                 {GetRef<Block>(block), replace_desc.in_bound_filling_block->block}});
+  auto new_block_sref = self->stmt2ref.at(replace_desc.const_filling_block->block.get());
+
+  // Set block info of created const pad value filling block
+  BlockInfo& block_info = self->block_info[new_block_sref];
+  block_info.affine_binding = true;
+  block_info.region_cover = true;
+  block_info.scope->stage_pipeline = true;
+
+  // If the const pad value filling block is lifted out of the original subtree,
+  // set the region_cover flag as false since region_cover is the property under the subtree.
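+  // The new const filling block must stay within the subtree rooted at the LCA of each
+  // consumer and the original block; otherwise that consumer's reads are no longer
+  // provably covered by producers under its own subtree.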
+  bool preserve_stage_pipeline = true;
+  for (const StmtSRef& consumer_sref : GetConsumers(self, block_sref)) {
+    StmtSRef lca = GetSRefLowestCommonAncestor({consumer_sref, block_sref});
+    const StmtSRefNode* parent = new_block_sref->parent;
+    bool is_under_lca = false;
+    while (parent) {
+      if (parent == lca.get()) {
+        is_under_lca = true;
+        break;
+      }
+      parent = parent->parent;
+    }
+    if (!is_under_lca) {
+      preserve_stage_pipeline = false;
+      self->block_info[consumer_sref].region_cover = false;
+    }
+  }
+  if (!preserve_stage_pipeline) {
+    self->block_info[scope_root_sref].scope->stage_pipeline = false;
+  }
+  return new_block_sref;
+}
+
+StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
+                          const StmtSRef& loop_sref) {
+  return DecomposePaddingImpl(self, block_sref, loop_sref, false);
+}
+
+bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref,
+                         const StmtSRef& loop_sref) {
+  try {
+    DecomposePaddingImpl(self, block_sref, loop_sref, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
+}
+
+/******** FFI ********/
+
+TVM_REGISTER_GLOBAL("tir.schedule.CanDecomposePadding")
+    .set_body_typed([](Schedule self, BlockRV block_rv, LoopRV loop_rv) {
+      return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv));
+    });
+
+/******** InstructionKind Registration ********/
+
+struct DecomposePaddingTraits : public UnpackedInstTraits<DecomposePaddingTraits> {
+  static constexpr const char* kName = "DecomposePadding";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 2;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv) {
+    return sch->DecomposePadding(block_rv, loop_rv);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv, LoopRV loop_rv) {
+    PythonAPICall py("decompose_padding");
+    py.Input("block", block_rv);
+    py.Input("loop", loop_rv);
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(DecomposePaddingTraits);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 99ca03b6c94a..ad9043e4f2db 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -102,30 +102,6 @@ class DecomposeReductionBlockReplacer : public StmtMutator {
   Block new_reduction_block_;
 };
 
-class LoopPositionError : public ScheduleError {
- public:
-  explicit LoopPositionError(IRModule mod, For loop, Block block)
-      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
-
-  String FastErrorString() const final {
-    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
-  }
-
-  String DetailRenderTemplate() const final {
-    std::ostringstream os;
-    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
-          "ancestor of block {1}.";
-    return os.str();
-  }
-
-  IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
-
-  IRModule mod_;
-  For loop_;
-  Block block_;
-};
-
 class LoopHeightError : public ScheduleError {
  public:
   static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
@@ -214,7 +190,8 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
   const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
   // Cond 0. Check loop_sref is an ancestor of block_sref
   if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
-    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
+                            "decompose_reduction");
   }
   // Cond 1. Check block is reduction
   StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index e386061ebfbd..091db344aadb 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -260,6 +260,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator")
       return self->SetAxisSeparator(
           block_rv, buffer_index, static_cast<BufferIndexType>(buffer_index_type),
           axis_separators);
     });
+
+/******** (FFI) Padding decomposition ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding")
+    .set_body_method<Schedule>(&ScheduleNode::DecomposePadding);
+
 /******** (FFI) Misc ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
     .set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 93e4c984a41b..3ca603acf8f5 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -517,6 +517,18 @@ void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_in
                                  /*outputs=*/{}));
 }
 
+/******** Schedule: Padding decomposition ********/
+BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
+  BlockRV new_block = ConcreteScheduleNode::DecomposePadding(block_rv, loop_rv);
+  static const InstructionKind& kind = InstructionKind::Get("DecomposePadding");
+  trace_->Append(/*inst=*/Instruction(
+      /*kind=*/kind,
+      /*inputs=*/{block_rv, loop_rv},
+      /*attrs=*/{},
+      /*outputs=*/{new_block}));
+  return new_block;
+}
+
 /******** Schedule: Misc ********/
 
 void TracedScheduleNode::EnterPostproc() {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index f6405d77a195..13848f12d2d8 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -107,6 +107,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                         BufferIndexType buffer_index_type,
                         const Array<IntImm>& axis_separators) final;
+  /******** Schedule: Padding decomposition ********/
+  BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final;
   /******** Schedule: Misc ********/
   void EnterPostproc() final;
 };
diff --git a/tests/python/unittest/test_tir_schedule_decompose_padding.py b/tests/python/unittest/test_tir_schedule_decompose_padding.py
new file mode 100644
index 000000000000..a3fc4326a3c9
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_decompose_padding.py
@@ -0,0 +1,313 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+def check_decompose_padding(origin, scheduled, expected, check_run=False):
+    tvm.ir.assert_structural_equal(scheduled, expected)
+    if check_run:
+        in_buffer = origin.buffer_map[origin.params[0]]
+        out_buffer = origin.buffer_map[origin.params[1]]
+        in_shape = [int(_) for _ in in_buffer.shape]
+        out_shape = [int(_) for _ in out_buffer.shape]
+        x = tvm.nd.array(np.random.uniform(0, 64, in_shape).astype(in_buffer.dtype))
+        y0 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype))
+        y1 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype))
+        f_origin = tvm.build(origin)
+        f_scheduled = tvm.build(scheduled)
+        f_origin(x, y0)
+        f_scheduled(x, y1)
+        tvm.testing.assert_allclose(y0.numpy(), y1.numpy())
+
+
+def test_1d_decompose_padding():
+    @T.prim_func
+    def before_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
+        for i in range(140):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32")
+
+    @T.prim_func
+    def after_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
+        for i in T.serial(140):
+            with T.block("block_pad_const"):
+                vi = T.axis.spatial(140, i)
+                T.reads()
+                T.writes(y[vi])
+                y[vi] = 0
+        for i in T.serial(128):
+            with T.block("block"):
+                vi = T.axis.spatial(128, i)
+                T.reads(x[vi])
+                T.writes(y[vi + 6])
+                y[vi + 6] = x[vi]
+
+    sch = tir.Schedule(before_decompose, debug_mask="all")
+    block = sch.get_block("block")
+    sch.decompose_padding(block, sch.get_loops(block)[0])
+    check_decompose_padding(before_decompose, sch.mod["main"], after_decompose, check_run=False)
+
+
+@T.prim_func
+def sum_pool_2d(
+    x: T.Buffer[(1, 16, 225, 225), "int8"], tensor: T.Buffer[(1, 16, 225, 225), "int8"]
+):
+    pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
+    for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
+        with T.block("pad_temp"):
+            ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(
+                3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
+                x[ax0, ax1, ax2 - 3, ax3 - 3],
+                T.int8(0),
+                dtype="int8",
+            )
+    for i0, i1, i2, i3, i4, i5 in T.grid(1, 16, 225, 225, 7, 7):
+        with T.block("tensor"):
+            ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
+            with T.init():
+                tensor[ax0, ax1, ax2, ax3] = T.int8(0)
+            tensor[ax0, ax1, ax2, ax3] = (
+                tensor[ax0, ax1, ax2, ax3] + pad_temp[ax0, ax1, ax2 + rv0, ax3 + rv1]
+            )
+
+
+def test_decompose_hw_padding_direct():
+    """Case 0. direct decompose"""
+
+    @T.prim_func
+    def pooling_decompose_0(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], tensor: T.Buffer[(1, 16, 225, 225), "int8"]
+    ):
+        pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
+        for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
+            with T.block("pad_temp_pad_const"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                pad_temp[ax0, ax1, ax2, ax3] = T.int8(0)
+        for i0, i1, i2, i3 in T.grid(1, 16, 225, 225):
+            with T.block("pad_temp"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                pad_temp[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]
+        for i0, i1, i2, i3, i4, i5 in T.grid(1, 16, 225, 225, 7, 7):
+            with T.block("tensor"):
+                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
+                with T.init():
+                    tensor[ax0, ax1, ax2, ax3] = T.int8(0)
+                tensor[ax0, ax1, ax2, ax3] = (
+                    tensor[ax0, ax1, ax2, ax3] + pad_temp[ax0, ax1, ax2 + rv0, ax3 + rv1]
+                )
+
+    sch = tir.Schedule(sum_pool_2d, debug_mask="all")
+    pad = sch.get_block("pad_temp")
+    sch.decompose_padding(pad, sch.get_loops(pad)[0])
+    check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_0, check_run=True)
+
+
+def test_decompose_hw_padding_tiled():
+    """Case 1. tiling and then decompose"""
+
+    @T.prim_func
+    def pooling_decompose_1(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], tensor: T.Buffer[(1, 16, 225, 225), "int8"]
+    ) -> None:
+        pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
+        for i0, i2_0, i3_0 in T.grid(1, 3, 3):
+            for ax0, ax1, ax2 in T.grid(16, 81, 81):
+                with T.block("pad_temp_pad_const"):
+                    ax0_1 = T.axis.spatial(1, 0)
+                    ax1_1 = T.axis.spatial(16, ax0)
+                    ax2_1 = T.axis.spatial(231, i2_0 * 75 + ax1)
+                    ax3 = T.axis.spatial(231, i3_0 * 75 + ax2)
+                    T.reads()
+                    T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3])
+                    pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0)
+            for ax0, ax1, ax2 in T.grid(16, 81, 81):
+                with T.block("pad_temp"):
+                    ax0_2 = T.axis.spatial(1, 0)
+                    ax1_2 = T.axis.spatial(16, ax0)
+                    ax2_2 = T.axis.spatial(225, i2_0 * 75 + ax1 - 3)
+                    ax3 = T.axis.spatial(225, i3_0 * 75 + ax2 - 3)
+                    T.where(
+                        3 <= i2_0 * 75 + ax1
+                        and i2_0 * 75 + ax1 < 228
+                        and 3 <= i3_0 * 75 + ax2
+                        and i3_0 * 75 + ax2 < 228
+                    )
+                    T.reads(x[ax0_2, ax1_2, ax2_2, ax3])
+                    T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3])
+                    pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3]
+            for i1, i2_1, i3_1, i4, i5 in T.grid(16, 75, 75, 7, 7):
+                with T.block("tensor"):
+                    ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1])
+                    ax2_3 = T.axis.spatial(225, i2_0 * 75 + i2_1)
+                    ax3 = T.axis.spatial(225, i3_0 * 75 + i3_1)
+                    rv0, rv1 = T.axis.remap("RR", [i4, i5])
+                    T.reads(pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
+                    T.writes(tensor[ax0_3, ax1_3, ax2_3, ax3])
+                    with T.init():
+                        tensor[ax0_3, ax1_3, ax2_3, ax3] = T.int8(0)
+                    tensor[ax0_3, ax1_3, ax2_3, ax3] = (
+                        tensor[ax0_3, ax1_3, ax2_3, ax3]
+                        + pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1]
+                    )
+
+    sch = tir.Schedule(sum_pool_2d, debug_mask="all")
+    block = sch.get_block("tensor")
+    pad = sch.get_block("pad_temp")
+    n, c, h, w, kh, kw = sch.get_loops(block)
+    ho, hi = sch.split(h, [3, 75])
+    wo, wi = sch.split(w, [3, 75])
+    sch.reorder(n, ho, wo, c, hi, wi, kh, kw)
+    sch.compute_at(sch.get_block("pad_temp"), wo)
+    sch.decompose_padding(pad, sch.get_loops(pad)[3])
+    check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_1, check_run=True)
+
+
+def test_decompose_hw_padding_tiled_and_lift_pad():
+    """Case 2. tiling and then decompose, lift const pad values to outer loop"""
+
+    @T.prim_func
+    def pooling_decompose_2(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], tensor: T.Buffer[(1, 16, 225, 225), "int8"]
+    ) -> None:
+        pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
+        for i0, i2_0, i3_0, ax0, ax1, ax2 in T.grid(1, 3, 3, 16, 81, 81):
+            with T.block("pad_temp_pad_const"):
+                ax0_1 = T.axis.spatial(1, 0)
+                ax1_1 = T.axis.spatial(16, ax0)
+                ax2_1 = T.axis.spatial(231, i2_0 * 75 + ax1)
+                ax3 = T.axis.spatial(231, i3_0 * 75 + ax2)
+                T.reads()
+                T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3])
+                pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0)
+        for i0, i2_0, i3_0 in T.grid(1, 3, 3):
+            for ax0, ax1, ax2 in T.grid(16, 81, 81):
+                with T.block("pad_temp"):
+                    ax0_2 = T.axis.spatial(1, 0)
+                    ax1_2 = T.axis.spatial(16, ax0)
+                    ax2_2 = T.axis.spatial(225, i2_0 * 75 + ax1 - 3)
+                    ax3 = T.axis.spatial(225, i3_0 * 75 + ax2 - 3)
+                    T.where(
+                        3 <= i2_0 * 75 + ax1
+                        and i2_0 * 75 + ax1 < 228
+                        and 3 <= i3_0 * 75 + ax2
+                        and i3_0 * 75 + ax2 < 228
+                    )
+                    T.reads(x[ax0_2, ax1_2, ax2_2, ax3])
+                    T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3])
+                    pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3]
+            for i1, i2_1, i3_1, i4, i5 in T.grid(16, 75, 75, 7, 7):
+                with T.block("tensor"):
+                    ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1])
+                    ax2_3 = T.axis.spatial(225, i2_0 * 75 + i2_1)
+                    ax3 = T.axis.spatial(225, i3_0 * 75 + i3_1)
+                    rv0, rv1 = T.axis.remap("RR", [i4, i5])
+                    T.reads(pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
+                    T.writes(tensor[ax0_3, ax1_3, ax2_3, ax3])
+                    with T.init():
+                        tensor[ax0_3, ax1_3, ax2_3, ax3] = T.int8(0)
+                    tensor[ax0_3, ax1_3, ax2_3, ax3] = (
+                        tensor[ax0_3, ax1_3, ax2_3, ax3]
+                        + pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1]
+                    )
+
+    sch = tir.Schedule(sum_pool_2d, debug_mask="all")
+    block = sch.get_block("tensor")
+    pad = sch.get_block("pad_temp")
+    n, c, h, w, kh, kw = sch.get_loops(block)
+    ho, hi = sch.split(h, [3, 75])
+    wo, wi = sch.split(w, [3, 75])
+    sch.reorder(n, ho, wo, c, hi, wi, kh, kw)
+    sch.compute_at(sch.get_block("pad_temp"), wo)
+    sch.decompose_padding(pad, sch.get_loops(pad)[0])
+    check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_2, check_run=True)
+
+
+def test_decompose_hw_padding_non_perfect_tiled():
+    """Case 3. non-perfect tiling and then decompose"""
+
+    @T.prim_func
+    def pooling_decompose_3(
+        x: T.Buffer[(1, 16, 225, 225), "int8"], tensor: T.Buffer[(1, 16, 225, 225), "int8"]
+    ) -> None:
+        pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
+        for i0, i2_0, i3_0 in T.grid(1, 3, 3):
+            for ax0, ax1, ax2 in T.grid(16, 86, 86):
+                with T.block("pad_temp_pad_const"):
+                    ax0_1 = T.axis.spatial(1, 0)
+                    ax1_1 = T.axis.spatial(16, ax0)
+                    ax2_1 = T.axis.spatial(231, i2_0 * 80 + ax1)
+                    ax3 = T.axis.spatial(231, i3_0 * 80 + ax2)
+                    T.where(i2_0 * 80 + ax1 < 231 and i3_0 * 80 + ax2 < 231)
+                    T.reads()
+                    T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3])
+                    pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0)
+            for ax0, ax1, ax2 in T.grid(16, 86, 86):
+                with T.block("pad_temp"):
+                    ax0_2 = T.axis.spatial(1, 0)
+                    ax1_2 = T.axis.spatial(16, ax0)
+                    ax2_2 = T.axis.spatial(225, i2_0 * 80 + ax1 - 3)
+                    ax3 = T.axis.spatial(225, i3_0 * 80 + ax2 - 3)
+                    T.where(
+                        3 <= i2_0 * 80 + ax1
+                        and i2_0 * 80 + ax1 < 228
+                        and 3 <= i3_0 * 80 + ax2
+                        and i3_0 * 80 + ax2 < 228
+                        and i2_0 * 80 + ax1 < 231
+                        and i3_0 * 80 + ax2 < 231
+                    )
+                    T.reads(x[ax0_2, ax1_2, ax2_2, ax3])
+                    T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3])
+                    pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3]
+            for i1, i2_1, i3_1, i4, i5 in T.grid(16, 80, 80, 7, 7):
+                with T.block("tensor"):
+                    ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1])
+                    ax2_3 = T.axis.spatial(225, i2_0 * 80 + i2_1)
+                    ax3 = T.axis.spatial(225, i3_0 * 80 + i3_1)
+                    rv0, rv1 = T.axis.remap("RR", [i4, i5])
+                    T.where(i2_0 * 80 + i2_1 < 225 and i3_0 * 80 + i3_1 < 225)
+                    T.reads(pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
+                    T.writes(tensor[ax0_3, ax1_3, ax2_3, ax3])
+                    with T.init():
+                        tensor[ax0_3, ax1_3, ax2_3, ax3] = T.int8(0)
+                    tensor[ax0_3, ax1_3, ax2_3, ax3] = (
+                        tensor[ax0_3, ax1_3, ax2_3, ax3]
+                        + pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1]
+                    )
+
+    sch = tir.Schedule(sum_pool_2d, debug_mask="all")
+    block = sch.get_block("tensor")
+    pad = sch.get_block("pad_temp")
+    n, c, h, w, kh, kw = sch.get_loops(block)
+    ho, hi = sch.split(h, [None, 80])
+    wo, wi = sch.split(w, [None, 80])
+    sch.reorder(n, ho, wo, c, hi, wi, kh, kw)
+    sch.compute_at(sch.get_block("pad_temp"), wo)
+    sch.decompose_padding(pad, sch.get_loops(pad)[3])
+    check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
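+
+
+# A minimal usage sketch (not part of the test suite above), reusing the `sum_pool_2d`
+# workload defined earlier in this file. Since TracedScheduleNode::DecomposePadding
+# appends a `DecomposePadding` instruction, the transformation is replayable from the
+# schedule trace:
+#
+#   sch = tir.Schedule(sum_pool_2d, debug_mask="all")
+#   pad = sch.get_block("pad_temp")
+#   sch.decompose_padding(pad, sch.get_loops(pad)[0])
+#   print(sch.trace)  # includes a decompose_padding call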